From 0e7bd9fd03ef5fc39eb4dab9e4c698ad06a12d2a Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 21 Jun 2025 16:18:33 +0300 Subject: [PATCH] factor out generic MemoryManager (#10910) * allocator -> memory * just moveout it * mm is abstracted * need entry abstraction * fix * mypy --- test/external/external_fuzz_tlsf.py | 2 +- test/external/external_test_am.py | 11 +- test/external/external_test_tlsf.py | 2 +- tinygrad/engine/memory.py | 2 +- tinygrad/runtime/ops_amd.py | 5 +- tinygrad/runtime/support/allocator.py | 97 ---------- tinygrad/runtime/support/am/amdev.py | 153 ++-------------- tinygrad/runtime/support/memory.py | 248 ++++++++++++++++++++++++++ 8 files changed, 274 insertions(+), 246 deletions(-) delete mode 100644 tinygrad/runtime/support/allocator.py create mode 100644 tinygrad/runtime/support/memory.py diff --git a/test/external/external_fuzz_tlsf.py b/test/external/external_fuzz_tlsf.py index 512ced0b59..66eb85ca90 100644 --- a/test/external/external_fuzz_tlsf.py +++ b/test/external/external_fuzz_tlsf.py @@ -1,7 +1,7 @@ import random from typing import Dict, Optional from tinygrad.helpers import getenv -from tinygrad.runtime.support.allocator import TLSFAllocator +from tinygrad.runtime.support.memory import TLSFAllocator class AllocatorFuzzer: def __init__(self, total_size): diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index dfe241fc6f..03371129f6 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -1,7 +1,8 @@ import unittest -from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext +from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableEntry from tinygrad.runtime.support.am.ip import AM_GMC from tinygrad.runtime.support.hcq import MMIOInterface +from tinygrad.runtime.support.memory import PageTableTraverseContext from tinygrad.runtime.autogen.am import am from tinygrad.helpers import mv_address @@ -23,7 +24,9 @@ class FakeAM: self.vram_mv = memoryview(bytearray(4 << 30)) self.vram = MMIOInterface(mv_address(self.vram_mv), self.vram_mv.nbytes) self.gmc = FakeGMC(self) - self.mm = AMMemoryManager(self, vram_size=4 << 30) + self.mm = AMMemoryManager(self, 4 << 30, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512], + pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2, + va_base=AMMemoryManager.va_allocator.base) self.is_booting = False self.ip_ver = {am.GC_HWIP: (11, 0, 0)} def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram) @@ -65,7 +68,7 @@ class TestAMPageTable(unittest.TestCase): exteranl_va = va + AMMemoryManager.va_allocator.base mm.map_range(vaddr=exteranl_va, size=sz, paddrs=[(va, sz)]) - ctx = AMPageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va) + ctx = PageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va) results = list(ctx.next(sz)) total_covered = 0 @@ -126,7 +129,7 @@ class TestAMPageTable(unittest.TestCase): # Finally can map and check paddrs mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(0xdead0000, 0x1000), (0xdead1000, 0xff000)]) - ctx = AMPageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000) + ctx = PageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000) for tup in ctx.next(0x100000): _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): diff --git 
a/test/external/external_test_tlsf.py b/test/external/external_test_tlsf.py index aa780fdd50..6db41f0389 100644 --- a/test/external/external_test_tlsf.py +++ b/test/external/external_test_tlsf.py @@ -1,5 +1,5 @@ import unittest -from tinygrad.runtime.support.allocator import TLSFAllocator +from tinygrad.runtime.support.memory import TLSFAllocator class TestTLSFAllocator(unittest.TestCase): def setUp(self): diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 96fd6a0678..2c62b3f6a1 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -5,7 +5,7 @@ from tinygrad.device import Device, Buffer from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up from tinygrad.uop.ops import Ops from tinygrad.dtype import dtypes, ImageDType -from tinygrad.runtime.support.allocator import TLSFAllocator +from tinygrad.runtime.support.memory import TLSFAllocator # **************** memory planning **************** diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 41c30e7b88..9f07dee4c3 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -14,9 +14,10 @@ from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt from tinygrad.runtime.autogen.am import am from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler from tinygrad.runtime.support.elf import elf_loader -from tinygrad.runtime.support.am.amdev import AMDev, AMMapping +from tinygrad.runtime.support.am.amdev import AMDev from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, setup_pci_bars from tinygrad.runtime.support.system import System, PCIDevice, MAP_FIXED, MAP_NORESERVE +from tinygrad.runtime.support.memory import VirtMapping from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import @@ -635,7 +636,7 @@ class KFDIface: raise RuntimeError("\n".join(report)) @dataclass -class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:AMMapping; has_cpu_mapping:bool # noqa: E702 +class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:VirtMapping; has_cpu_mapping:bool # noqa: E702 class PCIIface: gpus:list[Any] = [] diff --git a/tinygrad/runtime/support/allocator.py b/tinygrad/runtime/support/allocator.py deleted file mode 100644 index 9060716fbb..0000000000 --- a/tinygrad/runtime/support/allocator.py +++ /dev/null @@ -1,97 +0,0 @@ -import collections, functools -from tinygrad.helpers import round_up - -class TLSFAllocator: - """ - The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 level of buckets: - * 1st level is determined by the most significant bit of the size. - * 2nd level splits the covered memory of 1st level into @lv2_cnt entries. - - For each allocation request, the allocator searches for the smallest block that can fit the requested size. - For each deallocation request, the allocator merges the block with its neighbors if they are free. - """ - - def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16): - self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length() - self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)] - self.lv1_entries:list[int] = [0] * len(self.storage) - - # self.blocks is more like a linked list, where each entry is a contiguous block. 
- self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free - self._insert_block(0, size) - - @functools.cache - def lv1(self, size): return size.bit_length() - - @functools.cache - def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt)) - - def _insert_block(self, start:int, size:int, prev:int|None=None): - if prev is None: prev = self.blocks[start][2] - self.storage[self.lv1(size)][self.lv2(size)].append(start) - self.lv1_entries[self.lv1(size)] += 1 - self.blocks[start] = (size, start + size, prev, True) - return self - - def _remove_block(self, start:int, size:int, prev:int|None=None): - if prev is None: prev = self.blocks[start][2] - self.storage[self.lv1(size)][self.lv2(size)].remove(start) - self.lv1_entries[self.lv1(size)] -= 1 - self.blocks[start] = (size, start + size, prev, False) - return self - - def _split_block(self, start:int, size:int, new_size:int): - nxt = self.blocks[start][1] - assert self.blocks[start][3], "block must be free" - self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start) - if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3]) - return self - - def _merge_right(self, start:int): - size, nxt, _, is_free = self.blocks[start] - assert is_free, "block must be free" - - while is_free and nxt in self.blocks: - if (blk:=self.blocks[nxt])[3] is False: break - self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0]) - assert self.blocks[start][1] == blk[1] - _, nxt, _, _ = self.blocks.pop(nxt) - - if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3]) - - def _merge_block(self, start:int): - # Go left while blocks are free. Then merge all them right. - while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x - self._merge_right(start) - - def alloc(self, req_size:int, align:int=1) -> int: - req_size = max(self.block_size, req_size) # at least block size. - size = max(self.block_size, req_size + align - 1) - - # Round up the allocation size to the next bucket, so any entry there can fit the requested size. - size = round_up(size, (1 << size.bit_length() - self.l2_cnt)) - - # Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found. - for l1 in range(self.lv1(size), len(self.storage)): - if self.lv1_entries[l1] == 0: continue - for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)): - if len(self.storage[l1][l2]) > 0: - nsize = self.blocks[self.storage[l1][l2][0]][0] - assert nsize >= size, "block must be larger" - - # Block start address. - start = self.storage[l1][l2][0] - - # If request contains alignment, split the block into two parts. - if (new_start:=round_up(start, align)) != start: - self._split_block(start, nsize, new_start - start) - start, nsize = new_start, self.blocks[new_start][0] - - # If the block is larger than the requested size, split it into two parts. - if nsize > req_size: self._split_block(start, nsize, req_size) - self._remove_block(start, req_size) # Mark the block as allocated. 
- return start + self.base - raise MemoryError(f"Can't allocate {req_size} bytes") - - def free(self, start:int): - self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 298a0adfe7..0fb537d96b 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -1,10 +1,10 @@ from __future__ import annotations import ctypes, collections, time, dataclasses, functools, fcntl, os, hashlib -from tinygrad.helpers import mv_address, getenv, round_up, DEBUG, temp, fetch +from tinygrad.helpers import mv_address, getenv, DEBUG, temp, fetch from tinygrad.runtime.autogen.am import am from tinygrad.runtime.support.hcq import MMIOInterface from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs -from tinygrad.runtime.support.allocator import TLSFAllocator +from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA AM_DEBUG = getenv("AM_DEBUG", 0) @@ -90,9 +90,6 @@ class AMFirmware: def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size]) -@dataclasses.dataclass(frozen=True) -class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 - class AMPageTableEntry: def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q') @@ -100,144 +97,18 @@ class AMPageTableEntry: assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}" self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000) -class AMPageTableTraverseContext: - def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False, boot=False): - self.adev, self.vaddr, self.create_pts, self.free_pts, self.boot = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts, boot - self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))] + def entry(self, entry_id:int) -> int: return self.entries[entry_id] + def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0 + def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000 + def is_pte(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id]) - def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12)) - def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512 - - def level_down(self): - pt, pte_idx, _ = self.pt_stack[-1] - if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0: - assert self.create_pts, "Not allowed to create new page table" - pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True) - entry = pt.entries[pte_idx] - - assert not self.adev.gmc.is_pte_huge_page(entry), f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" - child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) - - self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) - return self.pt_stack[-1] 
- - def _try_free_pt(self) -> bool: - pt, _, _ = self.pt_stack[-1] - if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)): - self.adev.mm.pfree(pt.paddr) - parent_pt, parent_pte_idx, _ = self.pt_stack[-2] - parent_pt.set_entry(parent_pte_idx, 0x0, valid=False) - return True - return False - - def level_up(self): - while self._try_free_pt() or self.pt_stack[-1][1] == 512: - _, pt_cnt, _ = self.pt_stack.pop() - if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2]) - - def next(self, size:int, off=0): - while size > 0: - pt, pte_idx, pte_covers = self.pt_stack[-1] - if self.create_pts: - while pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down() - else: - while pt.lv!=am.AMDGPU_VM_PTB and not self.adev.gmc.is_pte_huge_page(pt.entries[pte_idx]): pt, pte_idx, pte_covers = self.level_down() - - entries = min(size // pte_covers, 512 - pte_idx) - assert entries > 0, "Invalid entries" - yield off, pt, pte_idx, entries, pte_covers - - size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers - self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers) - self.level_up() - -class AMMemoryManager: +class AMMemoryManager(MemoryManager): va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices. - def __init__(self, adev:AMDev, vram_size:int): - self.adev, self.vram_size = adev, vram_size - self.boot_allocator = TLSFAllocator(32 << 20, base=0) # per device - self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device - self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1) - - def _frag_size(self, va, sz, must_cover=True): - """ - Calculate the tlb fragment size for a given virtual address and size. - If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned. - Fragment 0 is 4KB, 1 is 8KB and so on. - """ - va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1)) - return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12 - - def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> AMMapping: - if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})") - - assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" - - ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True, boot=boot) - for paddr, psize in paddrs: - for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize): - for pte_off in range(pte_cnt): - assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}" - pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, - frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True) - + def on_range_mapped(self): # Invalidate TLB after mappings. 
- self.adev.gmc.flush_tlb(ip='GC', vmid=0) - self.adev.gmc.flush_tlb(ip='MM', vmid=0) - return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped) - - def unmap_range(self, vaddr:int, size:int): - if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})") - - ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True) - for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size): - for pte_id in range(pte_idx, pte_idx + pte_cnt): - assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}" - pt.set_entry(pte_id, paddr=0x0, valid=False) - - @staticmethod - def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align)) - - def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> AMMapping: - # Alloc physical memory and map it to the virtual address - va = self.alloc_vaddr(size:=round_up(size, 0x1000), align) - - if contiguous: paddrs = [(self.palloc(size, zero=True), size)] - else: - # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure. - paddrs = [] - ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True) - for off, _, _, seg_cnt, seg_size in ctx.next(size): - rem_len = seg_cnt * seg_size - while rem_len > 0: - # Try to allocate as long segment (power of 2) as possible - cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None - while cont_seg_sz >= 0x1000: - try: paddr = self.palloc(cont_seg_sz, zero=False) - except MemoryError: cont_seg_sz //= 2 - else: break - - if paddr is not None: paddrs += [(paddr, cont_seg_sz)] - else: - for paddr, _ in paddrs: self.pa_allocator.free(paddr) - raise MemoryError(f"Failed to allocate a contiguous page. 
(allocation size={size:#x})") - rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz - - return self.map_range(va, size, paddrs, uncached=uncached) - - def vfree(self, vm:AMMapping): - self.unmap_range(vm.va_addr, vm.size) - self.va_allocator.free(vm.va_addr) - for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr) - - def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int: - assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated" - paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align) - if zero: self.adev.vram[paddr:paddr+size] = bytes(size) - return paddr - - def pfree(self, paddr:int): self.pa_allocator.free(paddr) + self.dev.gmc.flush_tlb(ip='GC', vmid=0) + self.dev.gmc.flush_tlb(ip='MM', vmid=0) class AMDev: def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None): @@ -268,7 +139,9 @@ class AMDev: self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000005)) and (getenv("AM_RESET", 0) != 1) # Memory manager & firmware - self.mm = AMMemoryManager(self, self.vram_size) + self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512], + pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2, + va_base=AMMemoryManager.va_allocator.base) self.fw = AMFirmware(self) # Initialize IP blocks diff --git a/tinygrad/runtime/support/memory.py b/tinygrad/runtime/support/memory.py new file mode 100644 index 0000000000..62e11ec9e7 --- /dev/null +++ b/tinygrad/runtime/support/memory.py @@ -0,0 +1,248 @@ +import collections, functools, dataclasses +from typing import Any, ClassVar +from tinygrad.helpers import round_up, getenv + +class TLSFAllocator: + """ + The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 level of buckets: + * 1st level is determined by the most significant bit of the size. + * 2nd level splits the covered memory of 1st level into @lv2_cnt entries. + + For each allocation request, the allocator searches for the smallest block that can fit the requested size. + For each deallocation request, the allocator merges the block with its neighbors if they are free. + """ + + def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16): + self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length() + self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)] + self.lv1_entries:list[int] = [0] * len(self.storage) + + # self.blocks is more like a linked list, where each entry is a contiguous block. 
+ self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free + self._insert_block(0, size) + + @functools.cache + def lv1(self, size): return size.bit_length() + + @functools.cache + def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt)) + + def _insert_block(self, start:int, size:int, prev:int|None=None): + if prev is None: prev = self.blocks[start][2] + self.storage[self.lv1(size)][self.lv2(size)].append(start) + self.lv1_entries[self.lv1(size)] += 1 + self.blocks[start] = (size, start + size, prev, True) + return self + + def _remove_block(self, start:int, size:int, prev:int|None=None): + if prev is None: prev = self.blocks[start][2] + self.storage[self.lv1(size)][self.lv2(size)].remove(start) + self.lv1_entries[self.lv1(size)] -= 1 + self.blocks[start] = (size, start + size, prev, False) + return self + + def _split_block(self, start:int, size:int, new_size:int): + nxt = self.blocks[start][1] + assert self.blocks[start][3], "block must be free" + self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start) + if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3]) + return self + + def _merge_right(self, start:int): + size, nxt, _, is_free = self.blocks[start] + assert is_free, "block must be free" + + while is_free and nxt in self.blocks: + if (blk:=self.blocks[nxt])[3] is False: break + self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0]) + assert self.blocks[start][1] == blk[1] + _, nxt, _, _ = self.blocks.pop(nxt) + + if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3]) + + def _merge_block(self, start:int): + # Go left while blocks are free. Then merge all them right. + while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x + self._merge_right(start) + + def alloc(self, req_size:int, align:int=1) -> int: + req_size = max(self.block_size, req_size) # at least block size. + size = max(self.block_size, req_size + align - 1) + + # Round up the allocation size to the next bucket, so any entry there can fit the requested size. + size = round_up(size, (1 << size.bit_length() - self.l2_cnt)) + + # Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found. + for l1 in range(self.lv1(size), len(self.storage)): + if self.lv1_entries[l1] == 0: continue + for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)): + if len(self.storage[l1][l2]) > 0: + nsize = self.blocks[self.storage[l1][l2][0]][0] + assert nsize >= size, "block must be larger" + + # Block start address. + start = self.storage[l1][l2][0] + + # If request contains alignment, split the block into two parts. + if (new_start:=round_up(start, align)) != start: + self._split_block(start, nsize, new_start - start) + start, nsize = new_start, self.blocks[new_start][0] + + # If the block is larger than the requested size, split it into two parts. + if nsize > req_size: self._split_block(start, nsize, req_size) + self._remove_block(start, req_size) # Mark the block as allocated. 
+ return start + self.base + raise MemoryError(f"Can't allocate {req_size} bytes") + + def free(self, start:int): + self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base) + +# Memory Managment + +@dataclasses.dataclass(frozen=True) +class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 + +class PageTableTraverseContext: + def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False): + self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot + self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))] + + def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv] + def _pt_pte_size(self, pt): return self.dev.mm.pte_covers[pt.lv] + def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % self._pt_pte_cnt(pt.lv) + + def level_down(self): + pt, pte_idx, _ = self.pt_stack[-1] + + if not pt.valid(pte_idx): + assert self.create_pts, "Not allowed to create new page table" + pt.set_entry(pte_idx, self.dev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True) + + assert not pt.is_pte(pte_idx), f"Must be table pt={pt.paddr:#x}, {pt.lv=} {pte_idx=} {pt.read_fields(pte_idx)}" + child_page_table = self.dev.mm.pt_t(self.dev, pt.address(pte_idx), lv=pt.lv+1) + + self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) + return self.pt_stack[-1] + + def _try_free_pt(self) -> bool: + pt, _, _ = self.pt_stack[-1] + if self.free_pts and pt != self.dev.mm.root_page_table and all(not pt.valid(i) for i in range(self._pt_pte_cnt(self.pt_stack[-1][0].lv))): + self.dev.mm.pfree(pt.paddr) + parent_pt, parent_pte_idx, _ = self.pt_stack[-2] + parent_pt.set_entry(parent_pte_idx, 0x0, valid=False) + return True + return False + + def level_up(self): + while self._try_free_pt() or self.pt_stack[-1][1] == self._pt_pte_cnt(self.pt_stack[-1][0].lv): + pt, pt_cnt, _ = self.pt_stack.pop() + if pt_cnt == self._pt_pte_cnt(pt.lv): self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2]) + + def next(self, size:int, off=0): + while size > 0: + pt, pte_idx, pte_covers = self.pt_stack[-1] + if self.create_pts: + while pt.lv < self.dev.mm.first_page_lv or pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down() + else: + while not pt.is_pte(pte_idx): pt, pte_idx, pte_covers = self.level_down() + + entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx) + assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}" + yield off, pt, pte_idx, entries, pte_covers + + size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers + self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers) + self.level_up() + +class MemoryManager: + va_allocator: ClassVar[TLSFAllocator|None] = None + + def __init__(self, dev, vram_size:int, boot_size:int, pt_t, pte_cnt:list[int], pte_covers:list[int], first_lv:int, first_page_lv:int, va_base:int): + self.dev, self.vram_size, self.va_base = dev, vram_size, va_base + self.pt_t, self.pte_cnt, self.pte_covers, self.first_page_lv = pt_t, pte_cnt, pte_covers, first_page_lv + + self.boot_allocator = TLSFAllocator(boot_size, base=0) # per device + self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), 
base=self.boot_allocator.size) # per device + self.root_page_table = pt_t(self.dev, self.palloc(0x1000, zero=not self.dev.smi_dev, boot=True), lv=first_lv) + + def _frag_size(self, va, sz, must_cover=True): + """ + Calculate the tlb fragment size for a given virtual address and size. + If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned. + Fragment 0 is 4KB, 1 is 8KB and so on. + """ + va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1)) + return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12 + + def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping: + if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})") + + assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" + + ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot) + for paddr, psize in paddrs: + for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize): + for pte_off in range(pte_cnt): + assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}" + pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, + frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True) + + self.on_range_mapped() + return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped) + + def unmap_range(self, vaddr:int, size:int): + if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})") + + ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, free_pts=True) + for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size): + for pte_id in range(pte_idx, pte_idx + pte_cnt): + assert pt.valid(pte_id), f"PTE not mapped: {pt.entry(pte_id):#x}" + pt.set_entry(pte_id, paddr=0x0, valid=False) + + def on_range_mapped(self): pass + + @classmethod + def alloc_vaddr(cls, size:int, align=0x1000) -> int: + assert cls.va_allocator is not None, "must be set it" + return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align)) + + def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping: + # Alloc physical memory and map it to the virtual address + va = self.alloc_vaddr(size:=round_up(size, 0x1000), align) + + if contiguous: paddrs = [(self.palloc(size, zero=True), size)] + else: + # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure. + paddrs = [] + ctx = PageTableTraverseContext(self.dev, self.root_page_table, va, create_pts=True) + for off, _, _, seg_cnt, seg_size in ctx.next(size): + rem_len = seg_cnt * seg_size + while rem_len > 0: + # Try to allocate as long segment (power of 2) as possible + cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None + while cont_seg_sz >= 0x1000: + try: paddr = self.palloc(cont_seg_sz, zero=False) + except MemoryError: cont_seg_sz //= 2 + else: break + + if paddr is not None: paddrs += [(paddr, cont_seg_sz)] + else: + for paddr, _ in paddrs: self.pa_allocator.free(paddr) + raise MemoryError(f"Failed to allocate a contiguous page. 
(allocation size={size:#x})") + rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz + + return self.map_range(va, size, paddrs, uncached=uncached) + + def vfree(self, vm:VirtMapping): + assert self.va_allocator is not None, "must be set it" + self.unmap_range(vm.va_addr, vm.size) + self.va_allocator.free(vm.va_addr) + for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr) + + def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int: + assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated" + paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align) + if zero: self.dev.vram[paddr:paddr+size] = bytes(size) + return paddr + + def pfree(self, paddr:int): self.pa_allocator.free(paddr)