Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
factor out generic MemoryManager (#10910)
* allocator -> memory
* just move it out
* mm is abstracted
* need entry abstraction
* fix
* mypy
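In practice the change is that TLSFAllocator, together with the newly factored-out generic pieces, now lives in tinygrad.runtime.support.memory, while tinygrad.runtime.support.allocator is gone. A quick orientation (all four names are defined in the new file shown at the bottom of this diff):

# before this commit
# from tinygrad.runtime.support.allocator import TLSFAllocator
# after this commit
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager, VirtMapping, PageTableTraverseContext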
test/external/external_fuzz_tlsf.py (vendored, 2 changes)
@@ -1,7 +1,7 @@
 import random
 from typing import Dict, Optional
 from tinygrad.helpers import getenv
-from tinygrad.runtime.support.allocator import TLSFAllocator
+from tinygrad.runtime.support.memory import TLSFAllocator

 class AllocatorFuzzer:
   def __init__(self, total_size):
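Only the import path changes for the fuzzer; the relocated class keeps the same interface. A quick smoke test of the new location might look like this (sizes are arbitrary):

from tinygrad.runtime.support.memory import TLSFAllocator

alloc = TLSFAllocator(1 << 20, base=0)   # 1 MiB pool starting at offset 0
a = alloc.alloc(0x1000)                  # offset of a free 4 KiB block
b = alloc.alloc(0x2000, align=0x1000)    # aligned allocation
alloc.free(a)
alloc.free(b)                            # freed neighbors are merged back together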
test/external/external_test_am.py (vendored, 11 changes)
@@ -1,7 +1,8 @@
 import unittest
-from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
+from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableEntry
 from tinygrad.runtime.support.am.ip import AM_GMC
 from tinygrad.runtime.support.hcq import MMIOInterface
+from tinygrad.runtime.support.memory import PageTableTraverseContext
 from tinygrad.runtime.autogen.am import am
 from tinygrad.helpers import mv_address
@@ -23,7 +24,9 @@ class FakeAM:
     self.vram_mv = memoryview(bytearray(4 << 30))
     self.vram = MMIOInterface(mv_address(self.vram_mv), self.vram_mv.nbytes)
     self.gmc = FakeGMC(self)
-    self.mm = AMMemoryManager(self, vram_size=4 << 30)
+    self.mm = AMMemoryManager(self, 4 << 30, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512],
+      pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2,
+      va_base=AMMemoryManager.va_allocator.base)
     self.is_booting = False
     self.ip_ver = {am.GC_HWIP: (11, 0, 0)}
   def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram)
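The pte_covers list spells out how much address space one entry covers at each of the four page-table levels; evaluating the comprehension from the constructor call above gives the usual AMD GPUVM sizes:

>>> [(1 << ((9 * (3 - lv)) + 12)) for lv in range(4)]
[549755813888, 1073741824, 2097152, 4096]   # 512 GiB, 1 GiB, 2 MiB, 4 KiB per entry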
@@ -65,7 +68,7 @@ class TestAMPageTable(unittest.TestCase):
       exteranl_va = va + AMMemoryManager.va_allocator.base
       mm.map_range(vaddr=exteranl_va, size=sz, paddrs=[(va, sz)])

-      ctx = AMPageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va)
+      ctx = PageTableTraverseContext(self.d[0], mm.root_page_table, exteranl_va)
       results = list(ctx.next(sz))

       total_covered = 0
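Each tuple yielded by ctx.next() is (offset, page_table, pte_idx, n_ptes, pte_covers), so the coverage check this test builds up (the assertion itself is not shown above) reduces to something along these lines:

total_covered = sum(n_ptes * pte_covers for _, _, _, n_ptes, pte_covers in results)
assert total_covered == sz   # the mapped range must be covered exactly once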
@@ -126,7 +129,7 @@ class TestAMPageTable(unittest.TestCase):
    # Finally can map and check paddrs
    mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(0xdead0000, 0x1000), (0xdead1000, 0xff000)])

-   ctx = AMPageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000)
+   ctx = PageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000)
    for tup in ctx.next(0x100000):
      _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
      for i in range(_n_ptes):
test/external/external_test_tlsf.py (vendored, 2 changes)
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad.runtime.support.allocator import TLSFAllocator
+from tinygrad.runtime.support.memory import TLSFAllocator

 class TestTLSFAllocator(unittest.TestCase):
   def setUp(self):
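The unit test likewise only re-points its import. As an illustration of the split/merge behavior this kind of test exercises (a sketch, not the actual test body):

import unittest
from tinygrad.runtime.support.memory import TLSFAllocator

class TestTLSFSketch(unittest.TestCase):
  def test_alloc_free_reuse(self):
    a = TLSFAllocator(1 << 16)
    x, y = a.alloc(0x100), a.alloc(0x100)   # two adjacent blocks from the front of the pool
    a.free(x)
    a.free(y)                               # freeing both merges them back into one free block
    self.assertEqual(a.alloc(0x200), x)     # the merged block can satisfy a larger request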
@@ -5,7 +5,7 @@ from tinygrad.device import Device, Buffer
 from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
 from tinygrad.uop.ops import Ops
 from tinygrad.dtype import dtypes, ImageDType
-from tinygrad.runtime.support.allocator import TLSFAllocator
+from tinygrad.runtime.support.memory import TLSFAllocator

 # **************** memory planning ****************
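The memory planner reuses the same allocator to pack buffer offsets inside one backing allocation. A toy illustration of that idea (not the planner's actual code):

from tinygrad.runtime.support.memory import TLSFAllocator

pool = TLSFAllocator(64 << 20)          # plan offsets inside a hypothetical 64 MiB arena
offsets = {}
for name, size in [("a", 1 << 20), ("b", 2 << 20)]:
  offsets[name] = pool.alloc(size)
pool.free(offsets["a"])                 # offsets freed at end-of-lifetime can be reused by later buffers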
@@ -14,9 +14,10 @@ from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt
 from tinygrad.runtime.autogen.am import am
 from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
 from tinygrad.runtime.support.elf import elf_loader
-from tinygrad.runtime.support.am.amdev import AMDev, AMMapping
+from tinygrad.runtime.support.am.amdev import AMDev
 from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, setup_pci_bars
 from tinygrad.runtime.support.system import System, PCIDevice, MAP_FIXED, MAP_NORESERVE
+from tinygrad.runtime.support.memory import VirtMapping
 from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
 if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -635,7 +636,7 @@ class KFDIface:
     raise RuntimeError("\n".join(report))

 @dataclass
-class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:AMMapping; has_cpu_mapping:bool # noqa: E702
+class AMAllocationMeta: owner:AMDDevice; mapped_devs:list[AMDDevice]; mapping:VirtMapping; has_cpu_mapping:bool # noqa: E702

 class PCIIface:
   gpus:list[Any] = []
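VirtMapping carries the same fields the old AMMapping did (virtual address, size, backing physical ranges, plus caching flags), so the metadata swap is mechanical. For example:

from tinygrad.runtime.support.memory import VirtMapping

m = VirtMapping(va_addr=0x200000000000, size=0x2000, paddrs=[(0x1000, 0x1000), (0x8000, 0x1000)])
assert m.size == sum(sz for _, sz in m.paddrs)   # the mapping's size equals the sum of its physical ranges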
tinygrad/runtime/support/allocator.py (deleted)
@@ -1,97 +0,0 @@
import collections, functools
from tinygrad.helpers import round_up

class TLSFAllocator:
  """
  The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
    * 1st level is determined by the most significant bit of the size.
    * 2nd level splits the covered memory of 1st level into @lv2_cnt entries.

  For each allocation request, the allocator searches for the smallest block that can fit the requested size.
  For each deallocation request, the allocator merges the block with its neighbors if they are free.
  """

  def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
    self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
    self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
    self.lv1_entries:list[int] = [0] * len(self.storage)

    # self.blocks is more like a linked list, where each entry is a contiguous block.
    self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
    self._insert_block(0, size)

  @functools.cache
  def lv1(self, size): return size.bit_length()

  @functools.cache
  def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))

  def _insert_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].append(start)
    self.lv1_entries[self.lv1(size)] += 1
    self.blocks[start] = (size, start + size, prev, True)
    return self

  def _remove_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].remove(start)
    self.lv1_entries[self.lv1(size)] -= 1
    self.blocks[start] = (size, start + size, prev, False)
    return self

  def _split_block(self, start:int, size:int, new_size:int):
    nxt = self.blocks[start][1]
    assert self.blocks[start][3], "block must be free"
    self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
    if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
    return self

  def _merge_right(self, start:int):
    size, nxt, _, is_free = self.blocks[start]
    assert is_free, "block must be free"

    while is_free and nxt in self.blocks:
      if (blk:=self.blocks[nxt])[3] is False: break
      self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
      assert self.blocks[start][1] == blk[1]
      _, nxt, _, _ = self.blocks.pop(nxt)

    if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])

  def _merge_block(self, start:int):
    # Go left while blocks are free, then merge them all to the right.
    while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
    self._merge_right(start)

  def alloc(self, req_size:int, align:int=1) -> int:
    req_size = max(self.block_size, req_size) # at least block size.
    size = max(self.block_size, req_size + align - 1)

    # Round up the allocation size to the next bucket, so any entry there can fit the requested size.
    size = round_up(size, (1 << size.bit_length() - self.l2_cnt))

    # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
    for l1 in range(self.lv1(size), len(self.storage)):
      if self.lv1_entries[l1] == 0: continue
      for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
        if len(self.storage[l1][l2]) > 0:
          nsize = self.blocks[self.storage[l1][l2][0]][0]
          assert nsize >= size, "block must be larger"

          # Block start address.
          start = self.storage[l1][l2][0]

          # If request contains alignment, split the block into two parts.
          if (new_start:=round_up(start, align)) != start:
            self._split_block(start, nsize, new_start - start)
            start, nsize = new_start, self.blocks[new_start][0]

          # If the block is larger than the requested size, split it into two parts.
          if nsize > req_size: self._split_block(start, nsize, req_size)
          self._remove_block(start, req_size) # Mark the block as allocated.
          return start + self.base
    raise MemoryError(f"Can't allocate {req_size} bytes")

  def free(self, start:int):
    self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
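One detail worth calling out in the deleted (and relocated) allocator: alignment is honored by searching with req_size + align - 1 and then splitting the found block at the aligned boundary, so an aligned request does not fail just because the head of a free block is misaligned. A small sketch:

from tinygrad.runtime.support.memory import TLSFAllocator

a = TLSFAllocator(1 << 20, base=0)
x = a.alloc(0x10)                  # small block at the front of the pool
y = a.alloc(0x1000, align=0x1000)  # still lands on a 4 KiB boundary despite the misaligned head
assert y % 0x1000 == 0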
tinygrad/runtime/support/am/amdev.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 import ctypes, collections, time, dataclasses, functools, fcntl, os, hashlib
-from tinygrad.helpers import mv_address, getenv, round_up, DEBUG, temp, fetch
+from tinygrad.helpers import mv_address, getenv, DEBUG, temp, fetch
 from tinygrad.runtime.autogen.am import am
 from tinygrad.runtime.support.hcq import MMIOInterface
 from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
-from tinygrad.runtime.support.allocator import TLSFAllocator
+from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
 from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA

 AM_DEBUG = getenv("AM_DEBUG", 0)
@@ -90,9 +90,6 @@ class AMFirmware:
   def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])

-@dataclasses.dataclass(frozen=True)
-class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
-
 class AMPageTableEntry:
   def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')
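AMPageTableEntry views its 4 KiB page-table page as 512 little-endian 64-bit entries ('Q'). The low bits hold flags from the GMC block and bits 12..47 hold the physical address, which is why both the old and new traversal code mask entries with 0x0000FFFFFFFFF000. A small illustration (the entry value is hypothetical):

PTE_ADDR_MASK = 0x0000FFFFFFFFF000
entry = (0x1234F000 & PTE_ADDR_MASK) | 0x1    # frame 0x1234F000 with the valid bit set
assert entry & PTE_ADDR_MASK == 0x1234F000    # address() recovers the page frame
assert entry & 0x1                            # AMDGPU_PTE_VALID is the low bit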
@@ -100,144 +97,18 @@ class AMPageTableEntry:
     assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
     self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)

-class AMPageTableTraverseContext:
-  def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
-    self.adev, self.vaddr, self.create_pts, self.free_pts, self.boot = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts, boot
-    self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
+  def entry(self, entry_id:int) -> int: return self.entries[entry_id]
+  def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0
+  def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000
+  def is_pte(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id])

-  def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
-  def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512
-
-  def level_down(self):
-    pt, pte_idx, _ = self.pt_stack[-1]
-    if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
-      assert self.create_pts, "Not allowed to create new page table"
-      pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)
-      entry = pt.entries[pte_idx]
-
-    assert not self.adev.gmc.is_pte_huge_page(entry), f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
-    child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
-
-    self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
-    return self.pt_stack[-1]
-
-  def _try_free_pt(self) -> bool:
-    pt, _, _ = self.pt_stack[-1]
-    if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
-      self.adev.mm.pfree(pt.paddr)
-      parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
-      parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
-      return True
-    return False
-
-  def level_up(self):
-    while self._try_free_pt() or self.pt_stack[-1][1] == 512:
-      _, pt_cnt, _ = self.pt_stack.pop()
-      if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
-
-  def next(self, size:int, off=0):
-    while size > 0:
-      pt, pte_idx, pte_covers = self.pt_stack[-1]
-      if self.create_pts:
-        while pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
-      else:
-        while pt.lv!=am.AMDGPU_VM_PTB and not self.adev.gmc.is_pte_huge_page(pt.entries[pte_idx]): pt, pte_idx, pte_covers = self.level_down()
-
-      entries = min(size // pte_covers, 512 - pte_idx)
-      assert entries > 0, "Invalid entries"
-      yield off, pt, pte_idx, entries, pte_covers
-
-      size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
-      self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
-      self.level_up()
-
-class AMMemoryManager:
+class AMMemoryManager(MemoryManager):
   va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices.

-  def __init__(self, adev:AMDev, vram_size:int):
-    self.adev, self.vram_size = adev, vram_size
-    self.boot_allocator = TLSFAllocator(32 << 20, base=0) # per device
-    self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
-    self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
-
-  def _frag_size(self, va, sz, must_cover=True):
-    """
-    Calculate the TLB fragment size for a given virtual address and size.
-    If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
-    Fragment 0 is 4KB, 1 is 8KB and so on.
-    """
-    va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
-    return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12
-
-  def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> AMMapping:
-    if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
-
-    assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
-
-    ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True, boot=boot)
-    for paddr, psize in paddrs:
-      for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
-        for pte_off in range(pte_cnt):
-          assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
-          pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
-            frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)

+  def on_range_mapped(self):
     # Invalidate TLB after mappings.
-    self.adev.gmc.flush_tlb(ip='GC', vmid=0)
-    self.adev.gmc.flush_tlb(ip='MM', vmid=0)
-    return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
-
-  def unmap_range(self, vaddr:int, size:int):
-    if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
-
-    ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
-    for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
-      for pte_id in range(pte_idx, pte_idx + pte_cnt):
-        assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
-        pt.set_entry(pte_id, paddr=0x0, valid=False)
-
-  @staticmethod
-  def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
-
-  def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> AMMapping:
-    # Alloc physical memory and map it to the virtual address
-    va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)
-
-    if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
-    else:
-      # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
-      paddrs = []
-      ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
-      for off, _, _, seg_cnt, seg_size in ctx.next(size):
-        rem_len = seg_cnt * seg_size
-        while rem_len > 0:
-          # Try to allocate as long a segment (power of 2) as possible
-          cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None
-          while cont_seg_sz >= 0x1000:
-            try: paddr = self.palloc(cont_seg_sz, zero=False)
-            except MemoryError: cont_seg_sz //= 2
-            else: break
-
-          if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
-          else:
-            for paddr, _ in paddrs: self.pa_allocator.free(paddr)
-            raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
-          rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz
-
-    return self.map_range(va, size, paddrs, uncached=uncached)
-
-  def vfree(self, vm:AMMapping):
-    self.unmap_range(vm.va_addr, vm.size)
-    self.va_allocator.free(vm.va_addr)
-    for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
-
-  def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
-    assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
-    paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
-    if zero: self.adev.vram[paddr:paddr+size] = bytes(size)
-    return paddr
-
-  def pfree(self, paddr:int): self.pa_allocator.free(paddr)
+    self.dev.gmc.flush_tlb(ip='GC', vmid=0)
+    self.dev.gmc.flush_tlb(ip='MM', vmid=0)

 class AMDev:
   def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
@@ -268,7 +139,9 @@ class AMDev:
     self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000005)) and (getenv("AM_RESET", 0) != 1)

     # Memory manager & firmware
-    self.mm = AMMemoryManager(self, self.vram_size)
+    self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512],
+      pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2,
+      va_base=AMMemoryManager.va_allocator.base)
     self.fw = AMFirmware(self)

     # Initialize IP blocks
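After this change the driver only parameterizes the shared base class and overrides the hooks it cares about: the generic map_range (in memory.py below) calls on_range_mapped() once after the PTEs are written, and AMMemoryManager's override flushes the GC and MM TLBs. A hedged sketch of using that hook for something else, e.g. counting flush points in a test double:

from tinygrad.runtime.support.memory import MemoryManager

class CountingMM(MemoryManager):
  # hypothetical subclass: record how often map_range triggers the post-map hook
  flushes = 0
  def on_range_mapped(self): CountingMM.flushes += 1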
tinygrad/runtime/support/memory.py (new file, 248 lines)
@@ -0,0 +1,248 @@
import collections, functools, dataclasses
from typing import Any, ClassVar
from tinygrad.helpers import round_up, getenv

class TLSFAllocator:
  """
  The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
    * 1st level is determined by the most significant bit of the size.
    * 2nd level splits the covered memory of 1st level into @lv2_cnt entries.

  For each allocation request, the allocator searches for the smallest block that can fit the requested size.
  For each deallocation request, the allocator merges the block with its neighbors if they are free.
  """

  def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
    self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
    self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
    self.lv1_entries:list[int] = [0] * len(self.storage)

    # self.blocks is more like a linked list, where each entry is a contiguous block.
    self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
    self._insert_block(0, size)

  @functools.cache
  def lv1(self, size): return size.bit_length()

  @functools.cache
  def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))

  def _insert_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].append(start)
    self.lv1_entries[self.lv1(size)] += 1
    self.blocks[start] = (size, start + size, prev, True)
    return self

  def _remove_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].remove(start)
    self.lv1_entries[self.lv1(size)] -= 1
    self.blocks[start] = (size, start + size, prev, False)
    return self

  def _split_block(self, start:int, size:int, new_size:int):
    nxt = self.blocks[start][1]
    assert self.blocks[start][3], "block must be free"
    self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
    if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
    return self

  def _merge_right(self, start:int):
    size, nxt, _, is_free = self.blocks[start]
    assert is_free, "block must be free"

    while is_free and nxt in self.blocks:
      if (blk:=self.blocks[nxt])[3] is False: break
      self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
      assert self.blocks[start][1] == blk[1]
      _, nxt, _, _ = self.blocks.pop(nxt)

    if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])

  def _merge_block(self, start:int):
    # Go left while blocks are free, then merge them all to the right.
    while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
    self._merge_right(start)

  def alloc(self, req_size:int, align:int=1) -> int:
    req_size = max(self.block_size, req_size) # at least block size.
    size = max(self.block_size, req_size + align - 1)

    # Round up the allocation size to the next bucket, so any entry there can fit the requested size.
    size = round_up(size, (1 << size.bit_length() - self.l2_cnt))

    # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
    for l1 in range(self.lv1(size), len(self.storage)):
      if self.lv1_entries[l1] == 0: continue
      for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
        if len(self.storage[l1][l2]) > 0:
          nsize = self.blocks[self.storage[l1][l2][0]][0]
          assert nsize >= size, "block must be larger"

          # Block start address.
          start = self.storage[l1][l2][0]

          # If request contains alignment, split the block into two parts.
          if (new_start:=round_up(start, align)) != start:
            self._split_block(start, nsize, new_start - start)
            start, nsize = new_start, self.blocks[new_start][0]

          # If the block is larger than the requested size, split it into two parts.
          if nsize > req_size: self._split_block(start, nsize, req_size)
          self._remove_block(start, req_size) # Mark the block as allocated.
          return start + self.base
    raise MemoryError(f"Can't allocate {req_size} bytes")

  def free(self, start:int):
    self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)

# Memory Management

@dataclasses.dataclass(frozen=True)
class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702

class PageTableTraverseContext:
  def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
    self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot
    self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]

  def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv]
  def _pt_pte_size(self, pt): return self.dev.mm.pte_covers[pt.lv]
  def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % self._pt_pte_cnt(pt.lv)

  def level_down(self):
    pt, pte_idx, _ = self.pt_stack[-1]

    if not pt.valid(pte_idx):
      assert self.create_pts, "Not allowed to create new page table"
      pt.set_entry(pte_idx, self.dev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)

    assert not pt.is_pte(pte_idx), f"Must be table pt={pt.paddr:#x}, {pt.lv=} {pte_idx=} {pt.read_fields(pte_idx)}"
    child_page_table = self.dev.mm.pt_t(self.dev, pt.address(pte_idx), lv=pt.lv+1)

    self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
    return self.pt_stack[-1]

  def _try_free_pt(self) -> bool:
    pt, _, _ = self.pt_stack[-1]
    if self.free_pts and pt != self.dev.mm.root_page_table and all(not pt.valid(i) for i in range(self._pt_pte_cnt(self.pt_stack[-1][0].lv))):
      self.dev.mm.pfree(pt.paddr)
      parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
      parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
      return True
    return False

  def level_up(self):
    while self._try_free_pt() or self.pt_stack[-1][1] == self._pt_pte_cnt(self.pt_stack[-1][0].lv):
      pt, pt_cnt, _ = self.pt_stack.pop()
      if pt_cnt == self._pt_pte_cnt(pt.lv): self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])

  def next(self, size:int, off=0):
    while size > 0:
      pt, pte_idx, pte_covers = self.pt_stack[-1]
      if self.create_pts:
        while pt.lv < self.dev.mm.first_page_lv or pte_covers > size or self.vaddr & (pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
      else:
        while not pt.is_pte(pte_idx): pt, pte_idx, pte_covers = self.level_down()

      entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx)
      assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}"
      yield off, pt, pte_idx, entries, pte_covers

      size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
      self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
      self.level_up()

class MemoryManager:
  va_allocator: ClassVar[TLSFAllocator|None] = None

  def __init__(self, dev, vram_size:int, boot_size:int, pt_t, pte_cnt:list[int], pte_covers:list[int], first_lv:int, first_page_lv:int, va_base:int):
    self.dev, self.vram_size, self.va_base = dev, vram_size, va_base
    self.pt_t, self.pte_cnt, self.pte_covers, self.first_page_lv = pt_t, pte_cnt, pte_covers, first_page_lv

    self.boot_allocator = TLSFAllocator(boot_size, base=0) # per device
    self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
    self.root_page_table = pt_t(self.dev, self.palloc(0x1000, zero=not self.dev.smi_dev, boot=True), lv=first_lv)

  def _frag_size(self, va, sz, must_cover=True):
    """
    Calculate the TLB fragment size for a given virtual address and size.
    If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
    Fragment 0 is 4KB, 1 is 8KB and so on.
    """
    va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
    return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12

  def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping:
    if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")

    assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"

    ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot)
    for paddr, psize in paddrs:
      for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
        for pte_off in range(pte_cnt):
          assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
          pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
            frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)

    self.on_range_mapped()
    return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)

  def unmap_range(self, vaddr:int, size:int):
    if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")

    ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, free_pts=True)
    for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
      for pte_id in range(pte_idx, pte_idx + pte_cnt):
        assert pt.valid(pte_id), f"PTE not mapped: {pt.entry(pte_id):#x}"
        pt.set_entry(pte_id, paddr=0x0, valid=False)

  def on_range_mapped(self): pass

  @classmethod
  def alloc_vaddr(cls, size:int, align=0x1000) -> int:
    assert cls.va_allocator is not None, "must be set it"
    return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))

  def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
    # Alloc physical memory and map it to the virtual address
    va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)

    if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
    else:
      # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
      paddrs = []
      ctx = PageTableTraverseContext(self.dev, self.root_page_table, va, create_pts=True)
      for off, _, _, seg_cnt, seg_size in ctx.next(size):
        rem_len = seg_cnt * seg_size
        while rem_len > 0:
          # Try to allocate as long a segment (power of 2) as possible
          cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, rem_len) + 12), None
          while cont_seg_sz >= 0x1000:
            try: paddr = self.palloc(cont_seg_sz, zero=False)
            except MemoryError: cont_seg_sz //= 2
            else: break

          if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
          else:
            for paddr, _ in paddrs: self.pa_allocator.free(paddr)
            raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
          rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz

    return self.map_range(va, size, paddrs, uncached=uncached)

  def vfree(self, vm:VirtMapping):
    assert self.va_allocator is not None, "must be set it"
    self.unmap_range(vm.va_addr, vm.size)
    self.va_allocator.free(vm.va_addr)
    for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

  def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
    assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated"
    paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
    if zero: self.dev.vram[paddr:paddr+size] = bytes(size)
    return paddr

  def pfree(self, paddr:int): self.pa_allocator.free(paddr)
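Putting it together: a backend that has set up its MemoryManager allocates GPU memory through valloc/vfree (e.g. mapping = dev.mm.valloc(10 << 20); ...; dev.mm.vfree(mapping)), and the fragment math above decides how large a TLB fragment each mapping can advertise. A standalone copy of that arithmetic, for illustration only:

def frag_size(va:int, sz:int) -> int:
  # mirrors MemoryManager._frag_size with must_cover=True
  va_pwr2_div, sz_pwr2_div = (va & -va) if va > 0 else (1 << 63), sz & -sz
  return min(va_pwr2_div, sz_pwr2_div).bit_length() - 1 - 12

assert frag_size(0x200000, 0x200000) == 9   # a 2 MiB-aligned, 2 MiB-sized range -> 2 MiB (4 KiB << 9) fragments
assert frag_size(0x201000, 0x200000) == 0   # a misaligned start falls back to 4 KiB fragments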