factor out generic MemoryManager (#10910)

* allocator -> memory

* just move it out

* mm is abstracted

* need entry abstraction

* fix

* mypy

Author: nimlgen
Date: 2025-06-21 16:18:33 +03:00
Committed by: GitHub
Parent: c7ec913210
Commit: 0e7bd9fd03
8 changed files with 274 additions and 246 deletions
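
Overview: the diff moves TLSFAllocator from tinygrad.runtime.support.allocator to tinygrad.runtime.support.memory and splits the AM-specific memory manager into a reusable MemoryManager plus a backend-supplied page-table entry class (the pt_t constructor argument). As a rough, illustrative sketch (not part of the commit), the calls the new PageTableTraverseContext makes on an entry object amount to an interface like this:

from typing import Protocol

class PageTableLike(Protocol):  # hypothetical name; the diff passes a concrete class such as AMPageTableEntry as pt_t
  lv: int     # level of this table in the page-table hierarchy
  paddr: int  # physical address of the table itself
  def entry(self, entry_id:int) -> int: ...       # raw PTE value
  def valid(self, entry_id:int) -> bool: ...      # PTE valid bit is set
  def address(self, entry_id:int) -> int: ...     # physical address stored in the PTE
  def is_pte(self, entry_id:int) -> bool: ...     # leaf (or huge-page) entry, as opposed to a pointer to a child table
  def read_fields(self, entry_id:int): ...        # only used for debug messages in level_down
  def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False,
                snooped=False, frag=0, valid=False): ...  # keyword defaults assumed, not visible in this diff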

View File

@@ -1,7 +1,7 @@
import random
from typing import Dict, Optional
from tinygrad.helpers import getenv
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.memory import TLSFAllocator
class AllocatorFuzzer:
def __init__(self, total_size):

View File

@@ -1,7 +1,8 @@
import unittest
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableEntry
from tinygrad.runtime.support.am.ip import AM_GMC
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.memory import PageTableTraverseContext
from tinygrad.runtime.autogen.am import am
from tinygrad.helpers import mv_address
@@ -23,7 +24,9 @@ class FakeAM:
self.vram_mv = memoryview(bytearray(4 << 30))
self.vram = MMIOInterface(mv_address(self.vram_mv), self.vram_mv.nbytes)
self.gmc = FakeGMC(self)
self.mm = AMMemoryManager(self, vram_size=4 << 30)
self.mm = AMMemoryManager(self, 4 << 30, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512],
pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2,
va_base=AMMemoryManager.va_allocator.base)
self.is_booting = False
self.ip_ver = {am.GC_HWIP: (11, 0, 0)}
def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram)
@@ -65,7 +68,7 @@ class TestAMPageTable(unittest.TestCase):
exteranl_va = va + AMMemoryManager.va_allocator.base
mm.map_range(vaddr=exteranl_va, size=sz, paddrs=[(va, sz)])
ctx = AMPageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va)
ctx = PageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va)
results = list(ctx.next(sz))
total_covered = 0
@@ -126,7 +129,7 @@ class TestAMPageTable(unittest.TestCase):
# Finally can map and check paddrs
mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(0xdead0000, 0x1000), (0xdead1000, 0xff000)])
ctx = AMPageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000)
ctx = PageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000)
for tup in ctx.next(0x100000):
_offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
for i in range(_n_ptes):

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.memory import TLSFAllocator
class TestTLSFAllocator(unittest.TestCase):
def setUp(self):

View File

@@ -5,7 +5,7 @@ from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.memory import TLSFAllocator
# **************** memory planning ****************

View File

@@ -14,9 +14,10 @@ from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMapping
from tinygrad.runtime.support.am.amdev import AMDev
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, setup_pci_bars
from tinygrad.runtime.support.system import System, PCIDevice, MAP_FIXED, MAP_NORESERVE
from tinygrad.runtime.support.memory import VirtMapping
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -635,7 +636,7 @@ class KFDIface:
raise RuntimeError("\n".join(report))
@dataclass
class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:AMMapping; has_cpu_mapping:bool # noqa: E702
class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:VirtMapping; has_cpu_mapping:bool # noqa: E702
class PCIIface:
gpus:list[Any] = []

View File

@@ -1,97 +0,0 @@
import collections, functools
from tinygrad.helpers import round_up
class TLSFAllocator:
"""
The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
* 1st level is determined by the most significant bit of the size.
* 2nd level splits the covered memory of 1st level into @lv2_cnt entries.
For each allocation request, the allocator searches for the smallest block that can fit the requested size.
For each deallocation request, the allocator merges the block with its neighbors if they are free.
"""
def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
self.lv1_entries:list[int] = [0] * len(self.storage)
# self.blocks is more like a linked list, where each entry is a contiguous block.
self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
self._insert_block(0, size)
@functools.cache
def lv1(self, size): return size.bit_length()
@functools.cache
def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
def _insert_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].append(start)
self.lv1_entries[self.lv1(size)] += 1
self.blocks[start] = (size, start + size, prev, True)
return self
def _remove_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].remove(start)
self.lv1_entries[self.lv1(size)] -= 1
self.blocks[start] = (size, start + size, prev, False)
return self
def _split_block(self, start:int, size:int, new_size:int):
nxt = self.blocks[start][1]
assert self.blocks[start][3], "block must be free"
self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
return self
def _merge_right(self, start:int):
size, nxt, _, is_free = self.blocks[start]
assert is_free, "block must be free"
while is_free and nxt in self.blocks:
if (blk:=self.blocks[nxt])[3] is False: break
self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
assert self.blocks[start][1] == blk[1]
_, nxt, _, _ = self.blocks.pop(nxt)
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
def _merge_block(self, start:int):
# Go left while blocks are free, then merge them all to the right.
while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
self._merge_right(start)
def alloc(self, req_size:int, align:int=1) -> int:
req_size = max(self.block_size, req_size) # at least block size.
size = max(self.block_size, req_size + align - 1)
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
for l1 in range(self.lv1(size), len(self.storage)):
if self.lv1_entries[l1] == 0: continue
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
if len(self.storage[l1][l2]) > 0:
nsize = self.blocks[self.storage[l1][l2][0]][0]
assert nsize >= size, "block must be larger"
# Block start address.
start = self.storage[l1][l2][0]
# If the request has an alignment requirement, split the block into two parts.
if (new_start:=round_up(start, align)) != start:
self._split_block(start, nsize, new_start - start)
start, nsize = new_start, self.blocks[new_start][0]
# If the block is larger than the requested size, split it into two parts.
if nsize > req_size: self._split_block(start, nsize, req_size)
self._remove_block(start, req_size) # Mark the block as allocated.
return start + self.base
raise MemoryError(f"Can't allocate {req_size} bytes")
def free(self, start:int):
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
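
The TLSFAllocator deleted here reappears unchanged in the new memory module at the end of this diff. A minimal usage sketch of its interface (sizes and base are illustrative, not taken from the commit):

from tinygrad.runtime.support.memory import TLSFAllocator  # import path after this commit

heap = TLSFAllocator(16 << 20, base=0x1000)  # manage 16 MiB of address space starting at 0x1000
a = heap.alloc(0x4000)                       # smallest free block that fits; returns an address >= base
b = heap.alloc(0x2000, align=0x1000)         # alignment is satisfied by splitting the found block
heap.free(a)                                 # freed blocks are coalesced with free neighbors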

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, functools, fcntl, os, hashlib
from tinygrad.helpers import mv_address, getenv, round_up, DEBUG, temp, fetch
from tinygrad.helpers import mv_address, getenv, DEBUG, temp, fetch
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
AM_DEBUG = getenv("AM_DEBUG", 0)
@@ -90,9 +90,6 @@ class AMFirmware:
def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])
@dataclasses.dataclass(frozen=True)
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
class AMPageTableEntry:
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')
@@ -100,144 +97,18 @@ class AMPageTableEntry:
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
class AMPageTableTraverseContext:
def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
self.adev, self.vaddr, self.create_pts, self.free_pts, self.boot = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts, boot
self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
def entry(self, entry_id:int) -> int: return self.entries[entry_id]
def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0
def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000
def is_pte(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id])
def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512
def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
assert self.create_pts, "Not allowed to create new page table"
pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)
entry = pt.entries[pte_idx]
assert not self.adev.gmc.is_pte_huge_page(entry), f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]
def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
self.adev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
return True
return False
def level_up(self):
while self._try_free_pt() or self.pt_stack[-1][1] == 512:
_, pt_cnt, _ = self.pt_stack.pop()
if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
def next(self, size:int, off=0):
while size > 0:
pt, pte_idx, pte_covers = self.pt_stack[-1]
if self.create_pts:
while pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
else:
while pt.lv!=am.AMDGPU_VM_PTB and not self.adev.gmc.is_pte_huge_page(pt.entries[pte_idx]): pt, pte_idx, pte_covers = self.level_down()
entries = min(size // pte_covers, 512 - pte_idx)
assert entries > 0, "Invalid entries"
yield off, pt, pte_idx, entries, pte_covers
size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
self.level_up()
class AMMemoryManager:
class AMMemoryManager(MemoryManager):
va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices.
def __init__(self, adev:AMDev, vram_size:int):
self.adev, self.vram_size = adev, vram_size
self.boot_allocator = TLSFAllocator(32 << 20, base=0) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
def _frag_size(self, va, sz, must_cover=True):
"""
Calculate the tlb fragment size for a given virtual address and size.
If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
Fragment 0 is 4KB, 1 is 8KB and so on.
"""
va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> AMMapping:
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True, boot=boot)
for paddr, psize in paddrs:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
for pte_off in range(pte_cnt):
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
def on_range_mapped(self):
# Invalidate TLB after mappings.
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
self.adev.gmc.flush_tlb(ip='MM', vmid=0)
return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
def unmap_range(self, vaddr:int, size:int):
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
for pte_id in range(pte_idx, pte_idx + pte_cnt):
assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
pt.set_entry(pte_id, paddr=0x0, valid=False)
@staticmethod
def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> AMMapping:
# Alloc physical memory and map it to the virtual address
va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)
if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
else:
# Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
paddrs = []
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for off, _, _, seg_cnt, seg_size in ctx.next(size):
rem_len = seg_cnt * seg_size
while rem_len > 0:
# Try to allocate as long a segment (power of 2) as possible
cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None
while cont_seg_sz >= 0x1000:
try: paddr = self.palloc(cont_seg_sz, zero=False)
except MemoryError: cont_seg_sz //= 2
else: break
if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
else:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz
return self.map_range(va, size, paddrs, uncached=uncached)
def vfree(self, vm:AMMapping):
self.unmap_range(vm.va_addr, vm.size)
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
if zero: self.adev.vram[paddr:paddr+size] = bytes(size)
return paddr
def pfree(self, paddr:int): self.pa_allocator.free(paddr)
self.dev.gmc.flush_tlb(ip='GC', vmid=0)
self.dev.gmc.flush_tlb(ip='MM', vmid=0)
class AMDev:
def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
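
The _frag_size helper above (kept verbatim in the generic MemoryManager added below) derives the TLB fragment field from the common power-of-two alignment of the virtual address and the size; fragment f covers 1 << (f + 12) bytes. A small illustrative check of that arithmetic, not part of the commit:

def frag_size(va:int, sz:int) -> int:  # same formula as _frag_size with must_cover=True
  va_pwr2_div, sz_pwr2_div = (va & -va) if va > 0 else (1 << 63), sz & -sz
  return min(va_pwr2_div, sz_pwr2_div).bit_length() - 1 - 12

assert frag_size(0x200000, 0x200000) == 9  # 2MiB-aligned 2MiB range -> 2MiB fragments
assert frag_size(0x201000, 0x200000) == 0  # va only 4KiB-aligned -> falls back to 4KiB fragments
assert frag_size(0, 0x1000) == 0           # va == 0 is treated as maximally aligned; 4KiB size -> fragment 0
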
@@ -268,7 +139,9 @@ class AMDev:
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000005)) and (getenv("AM_RESET", 0) != 1)
# Memory manager & firmware
self.mm = AMMemoryManager(self, self.vram_size)
self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512],
pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2,
va_base=AMMemoryManager.va_allocator.base)
self.fw = AMFirmware(self)
# Initialize IP blocks
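
For reference, the pte_cnt/pte_covers arguments above describe the 4-level VM layout: 512 entries per table and, per level lv, a per-entry coverage of 1 << (9*(3-lv) + 12) bytes. Spelling out that arithmetic (illustrative):

pte_covers = [1 << ((9 * (3 - lv)) + 12) for lv in range(4)]
assert pte_covers == [512 << 30, 1 << 30, 2 << 20, 4 << 10]  # 512 GiB, 1 GiB, 2 MiB, 4 KiB per entry
# with 512 entries per table, each level covers exactly 512x what the next level down covers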

View File

@@ -0,0 +1,248 @@
import collections, functools, dataclasses
from typing import Any, ClassVar
from tinygrad.helpers import round_up, getenv
class TLSFAllocator:
"""
The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
* 1st level is determined by the most significant bit of the size.
* 2nd level splits the covered memory of 1st level into @lv2_cnt entries.
For each allocation request, the allocator searches for the smallest block that can fit the requested size.
For each deallocation request, the allocator merges the block with its neighbors if they are free.
"""
def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
self.lv1_entries:list[int] = [0] * len(self.storage)
# self.blocks is more like a linked list, where each entry is a contiguous block.
self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
self._insert_block(0, size)
@functools.cache
def lv1(self, size): return size.bit_length()
@functools.cache
def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
def _insert_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].append(start)
self.lv1_entries[self.lv1(size)] += 1
self.blocks[start] = (size, start + size, prev, True)
return self
def _remove_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].remove(start)
self.lv1_entries[self.lv1(size)] -= 1
self.blocks[start] = (size, start + size, prev, False)
return self
def _split_block(self, start:int, size:int, new_size:int):
nxt = self.blocks[start][1]
assert self.blocks[start][3], "block must be free"
self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
return self
def _merge_right(self, start:int):
size, nxt, _, is_free = self.blocks[start]
assert is_free, "block must be free"
while is_free and nxt in self.blocks:
if (blk:=self.blocks[nxt])[3] is False: break
self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
assert self.blocks[start][1] == blk[1]
_, nxt, _, _ = self.blocks.pop(nxt)
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
def _merge_block(self, start:int):
# Go left while blocks are free, then merge them all to the right.
while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
self._merge_right(start)
def alloc(self, req_size:int, align:int=1) -> int:
req_size = max(self.block_size, req_size) # at least block size.
size = max(self.block_size, req_size + align - 1)
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
for l1 in range(self.lv1(size), len(self.storage)):
if self.lv1_entries[l1] == 0: continue
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
if len(self.storage[l1][l2]) > 0:
nsize = self.blocks[self.storage[l1][l2][0]][0]
assert nsize >= size, "block must be larger"
# Block start address.
start = self.storage[l1][l2][0]
# If the request has an alignment requirement, split the block into two parts.
if (new_start:=round_up(start, align)) != start:
self._split_block(start, nsize, new_start - start)
start, nsize = new_start, self.blocks[new_start][0]
# If the block is larger than the requested size, split it into two parts.
if nsize > req_size: self._split_block(start, nsize, req_size)
self._remove_block(start, req_size) # Mark the block as allocated.
return start + self.base
raise MemoryError(f"Can't allocate {req_size} bytes")
def free(self, start:int):
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
# Memory Management
@dataclasses.dataclass(frozen=True)
class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
class PageTableTraverseContext:
def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot
self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv]
def _pt_pte_size(self, pt): return self.dev.mm.pte_covers[pt.lv]
def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % self._pt_pte_cnt(pt.lv)
def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if not pt.valid(pte_idx):
assert self.create_pts, "Not allowed to create new page table"
pt.set_entry(pte_idx, self.dev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)
assert not pt.is_pte(pte_idx), f"Must be table pt={pt.paddr:#x}, {pt.lv=} {pte_idx=} {pt.read_fields(pte_idx)}"
child_page_table = self.dev.mm.pt_t(self.dev, pt.address(pte_idx), lv=pt.lv+1)
self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]
def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.dev.mm.root_page_table and all(not pt.valid(i) for i in range(self._pt_pte_cnt(self.pt_stack[-1][0].lv))):
self.dev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
return True
return False
def level_up(self):
while self._try_free_pt() or self.pt_stack[-1][1] == self._pt_pte_cnt(self.pt_stack[-1][0].lv):
pt, pt_cnt, _ = self.pt_stack.pop()
if pt_cnt == self._pt_pte_cnt(pt.lv): self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
def next(self, size:int, off=0):
while size > 0:
pt, pte_idx, pte_covers = self.pt_stack[-1]
if self.create_pts:
while pt.lv < self.dev.mm.first_page_lv or pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
else:
while not pt.is_pte(pte_idx): pt, pte_idx, pte_covers = self.level_down()
entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx)
assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}"
yield off, pt, pte_idx, entries, pte_covers
size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
self.level_up()
class MemoryManager:
va_allocator: ClassVar[TLSFAllocator|None] = None
def __init__(self, dev, vram_size:int, boot_size:int, pt_t, pte_cnt:list[int], pte_covers:list[int], first_lv:int, first_page_lv:int, va_base:int):
self.dev, self.vram_size, self.va_base = dev, vram_size, va_base
self.pt_t, self.pte_cnt, self.pte_covers, self.first_page_lv = pt_t, pte_cnt, pte_covers, first_page_lv
self.boot_allocator = TLSFAllocator(boot_size, base=0) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
self.root_page_table = pt_t(self.dev, self.palloc(0x1000, zero=not self.dev.smi_dev, boot=True), lv=first_lv)
def _frag_size(self, va, sz, must_cover=True):
"""
Calculate the tlb fragment size for a given virtual address and size.
If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
Fragment 0 is 4KB, 1 is 8KB and so on.
"""
va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping:
if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot)
for paddr, psize in paddrs:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
for pte_off in range(pte_cnt):
assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
self.on_range_mapped()
return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
def unmap_range(self, vaddr:int, size:int):
if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, free_pts=True)
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
for pte_id in range(pte_idx, pte_idx + pte_cnt):
assert pt.valid(pte_id), f"PTE not mapped: {pt.entry(pte_id):#x}"
pt.set_entry(pte_id, paddr=0x0, valid=False)
def on_range_mapped(self): pass
@classmethod
def alloc_vaddr(cls, size:int, align=0x1000) -> int:
assert cls.va_allocator is not None, "va_allocator must be set"
return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
# Alloc physical memory and map it to the virtual address
va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)
if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
else:
# Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
paddrs = []
ctx = PageTableTraverseContext(self.dev, self.root_page_table, va, create_pts=True)
for off, _, _, seg_cnt, seg_size in ctx.next(size):
rem_len = seg_cnt * seg_size
while rem_len > 0:
# Try to allocate as long a segment (power of 2) as possible
cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None
while cont_seg_sz >= 0x1000:
try: paddr = self.palloc(cont_seg_sz, zero=False)
except MemoryError: cont_seg_sz //= 2
else: break
if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
else:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz
return self.map_range(va, size, paddrs, uncached=uncached)
def vfree(self, vm:VirtMapping):
assert self.va_allocator is not None, "va_allocator must be set"
self.unmap_range(vm.va_addr, vm.size)
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated"
paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
if zero: self.dev.vram[paddr:paddr+size] = bytes(size)
return paddr
def pfree(self, paddr:int): self.pa_allocator.free(paddr)
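
valloc above backs a virtual range with the largest power-of-two physical segments the pa_allocator can still satisfy, retrying with smaller segments on MemoryError and unwinding everything if even a 4 KiB page cannot be found. A self-contained, simplified sketch of that carving loop (hypothetical helper; it ignores the virtual-address alignment the real code factors in through _frag_size, and assumes size is a multiple of 0x1000):

from tinygrad.runtime.support.memory import TLSFAllocator

pa_allocator = TLSFAllocator(16 << 20)  # stand-in for the per-device physical allocator

def alloc_segments(size:int) -> list[tuple[int, int]]:
  paddrs, rem = [], size
  while rem > 0:
    seg = 1 << (rem.bit_length() - 1)   # largest power of two that still fits the remaining length
    while seg >= 0x1000:
      try:
        paddrs.append((pa_allocator.alloc(seg), seg))
        break
      except MemoryError: seg //= 2     # fall back to a smaller contiguous segment
    else:                               # nothing >= 4 KiB left: undo partial allocations and give up
      for paddr, _ in paddrs: pa_allocator.free(paddr)
      raise MemoryError(f"Failed to back {size:#x} bytes")
    rem -= seg
  return paddrs

print(alloc_segments(0x5000))  # on a fresh allocator: [(0, 0x4000), (0x4000, 0x1000)]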