am: xgmi p2p (#13811)

* system: use addr space

* am: xgmi

* fix

* ugh
This commit is contained in:
nimlgen
2025-12-23 20:11:38 +03:00
committed by GitHub
parent 6439a515be
commit 90b217896f
7 changed files with 47 additions and 38 deletions

View File

@@ -2,7 +2,7 @@ import unittest
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableEntry
from tinygrad.runtime.support.am.ip import AM_GMC
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.memory import PageTableTraverseContext
from tinygrad.runtime.support.memory import PageTableTraverseContext, AddrSpace
from tinygrad.runtime.autogen.am import am
from tinygrad.helpers import mv_address
@@ -70,7 +70,7 @@ class TestAMPageTable(unittest.TestCase):
for va,sz in [(0x10000, 0x3000), (0x11000, 0x300000), (0x10000, 0x2000), (0x11000, 0x5000),
(0x2000000, 0x2000), (0x4000000, 0x4000000), (0x38000, 0x303000), (0x8000, 0x1000)]:
mm.map_range(vaddr=helper_va(va), size=sz, paddrs=[(va, sz)])
mm.map_range(vaddr=helper_va(va), size=sz, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
ctx = PageTableTraverseContext(self.d[0], mm.root_page_table, helper_va(va))
results = list(ctx.next(sz))
@@ -102,8 +102,8 @@ class TestAMPageTable(unittest.TestCase):
mm0 = self.d[0].mm
for (va1,sz1),(va2,sz2) in [((0x10000, (0x1000)), (0x11000, (2 << 20)))]:
mm0.map_range(vaddr=helper_va(va1), size=sz1, paddrs=[(va1, sz1)])
mm0.map_range(vaddr=helper_va(va2), size=sz2, paddrs=[(va2, sz2)])
mm0.map_range(vaddr=helper_va(va1), size=sz1, paddrs=[(va1, sz1)], aspace=AddrSpace.PHYS)
mm0.map_range(vaddr=helper_va(va2), size=sz2, paddrs=[(va2, sz2)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(va2), sz2)
mm0.unmap_range(helper_va(va1), sz1)
@@ -112,24 +112,24 @@ class TestAMPageTable(unittest.TestCase):
for va,sz in [(0x10000, 0x3000), (0x1000000, 0x1000000), (0x12000, 0x4000)]:
exteranl_va = helper_va(va)
mm0.map_range(vaddr=exteranl_va, size=sz, paddrs=[(va, sz)])
mm0.map_range(vaddr=exteranl_va, size=sz, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
with self.assertRaises(AssertionError):
mm0.map_range(vaddr=exteranl_va, size=0x1000, paddrs=[(va, sz)])
mm0.map_range(vaddr=exteranl_va, size=0x1000, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
with self.assertRaises(AssertionError):
mm0.map_range(vaddr=exteranl_va, size=0x100000, paddrs=[(va, sz)])
mm0.map_range(vaddr=exteranl_va, size=0x100000, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
with self.assertRaises(AssertionError):
mm0.map_range(vaddr=exteranl_va + 0x1000, size=0x1000, paddrs=[(va, sz)])
mm0.map_range(vaddr=exteranl_va + 0x1000, size=0x1000, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
with self.assertRaises(AssertionError):
mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(va, sz)])
mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(va, sz)], aspace=AddrSpace.PHYS)
mm0.unmap_range(vaddr=exteranl_va, size=sz)
# Finally can map and check paddrs
mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(0xdead0000, 0x1000), (0xdead1000, 0xff000)])
mm0.map_range(vaddr=exteranl_va + 0x2000, size=0x100000, paddrs=[(0xdead0000, 0x1000), (0xdead1000, 0xff000)], aspace=AddrSpace.PHYS)
ctx = PageTableTraverseContext(self.d[0], mm0.root_page_table, exteranl_va + 0x2000)
for tup in ctx.next(0x100000):
@@ -147,13 +147,13 @@ class TestAMPageTable(unittest.TestCase):
with self.assertRaises(AssertionError):
mm0.unmap_range(helper_va(0x10000), 0x3000)
mm0.map_range(helper_va(0x10000), 0x3000, paddrs=[(0x10000, 0x3000)])
mm0.map_range(helper_va(0x10000), 0x3000, paddrs=[(0x10000, 0x3000)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x10000), 0x3000)
with self.assertRaises(AssertionError):
mm0.unmap_range(helper_va(0x10000), 0x3000)
mm0.map_range(helper_va(0x10000), 0x3000, paddrs=[(0x10000, 0x3000)])
mm0.map_range(helper_va(0x10000), 0x3000, paddrs=[(0x10000, 0x3000)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x10000), 0x3000)
with self.assertRaises(AssertionError):
@@ -164,16 +164,16 @@ class TestAMPageTable(unittest.TestCase):
# offset from start
for off in [0, 0x3000, 0x10000]:
mm0.map_range(helper_va(0x1000000) + off, (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000))
mm0.map_range(helper_va(0x1000000) + off, (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000), aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x1000000) + off, (2 << 20) - off)
mm0.map_range(helper_va(0x1000000), 2 << 20, paddrs=[(0x10000, 2 << 20)])
mm0.map_range(helper_va(0x1000000), 2 << 20, paddrs=[(0x10000, 2 << 20)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x1000000), 2 << 20)
# offset from end
for off in [0x1000, 0x20000]:
mm0.map_range(helper_va(0x1000000), (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000))
mm0.map_range(helper_va(0x1000000), (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000), aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x1000000), (2 << 20) - off)
mm0.map_range(helper_va(0x1000000), 2 << 20, paddrs=[(0x10000, 2 << 20)])
mm0.map_range(helper_va(0x1000000), 2 << 20, paddrs=[(0x10000, 2 << 20)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x1000000), 2 << 20)
def test_frag_size(self):

View File

@@ -18,6 +18,7 @@ from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_ip_offsets, import_pmc
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, PCIDevice, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
from tinygrad.runtime.support.memory import AddrSpace
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
@@ -859,7 +860,7 @@ class USBIface(PCIIface):
self.sys_buf, self.sys_next_off = self._dma_region(ctrl_addr=0xa000, sys_addr=0x820000, size=0x1000), 0x800
def _dma_region(self, ctrl_addr, sys_addr, size):
region = self.dev_impl.mm.map_range(vaddr:=self.dev_impl.mm.alloc_vaddr(size=size), size, [(sys_addr, size)], system=True, uncached=True)
region = self.dev_impl.mm.map_range(vaddr:=self.dev_impl.mm.alloc_vaddr(size=size), size, [(sys_addr, size)], aspace=AddrSpace.SYS, uncached=True)
return HCQBuffer(vaddr, size, meta=PCIAllocationMeta(region, has_cpu_mapping=False), view=self.pci_dev.dma_view(ctrl_addr, size), owner=self.dev)
def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, **kwargs) -> HCQBuffer:

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import mv_address, getenv, DEBUG, fetch
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager, AddrSpace
from tinygrad.runtime.support.system import PCIDevice, PCIDevImplBase
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
@@ -120,10 +120,11 @@ class AMFirmware:
class AMPageTableEntry:
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
if not system: paddr = self.adev.paddr2xgmi(paddr)
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, aspace=AddrSpace.PHYS, snooped=False, frag=0, valid=True):
is_sys = aspace is AddrSpace.SYS
if aspace is AddrSpace.PHYS: paddr = self.adev.paddr2xgmi(paddr)
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, is_sys, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
def entry(self, entry_id:int) -> int: return self.entries[entry_id]
def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0

View File

@@ -3,6 +3,7 @@ from typing import Literal
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG, wait_cond, pad_bytes
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.amd import import_soc
from tinygrad.runtime.support.memory import AddrSpace
class AM_IP:
def __init__(self, adev): self.adev = adev
@@ -468,7 +469,7 @@ class AM_PSP(AM_IP):
msg1_region = next((reg for reg in self.adev.dma_regions or [] if reg[1].nbytes >= (512 << 10)), None)
if msg1_region is not None:
self.msg1_addr, self.msg1_view = self.adev.mm.alloc_vaddr(size=msg1_region[1].nbytes, align=am.PSP_1_MEG), msg1_region[1]
self.adev.mm.map_range(self.msg1_addr, msg1_region[1].nbytes, [(msg1_region[0], msg1_region[1].nbytes)], system=True, uncached=True, boot=True)
self.adev.mm.map_range(self.msg1_addr, msg1_region[1].nbytes, [(msg1_region[0],msg1_region[1].nbytes)], AddrSpace.SYS, uncached=True, boot=True)
else:
self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=False, boot=True)
self.msg1_addr, self.msg1_view = self.adev.paddr2mc(self.msg1_paddr), self.adev.vram.view(self.msg1_paddr, am.PSP_1_MEG, 'B')

View File

@@ -1,4 +1,4 @@
import collections, functools, dataclasses
import collections, functools, dataclasses, enum
from typing import Any, ClassVar
from tinygrad.helpers import round_up, getenv
@@ -107,8 +107,10 @@ class TLSFAllocator:
# Memory Management
class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702
@dataclasses.dataclass(frozen=True)
class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; aspace:AddrSpace; uncached:bool=False; snooped:bool=False # noqa: E702
class PageTableTraverseContext:
def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
@@ -190,7 +192,7 @@ class MemoryManager:
ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True)
for _ in ctx.next(size, paddr=0): return [pt for pt, _, _ in ctx.pt_stack]
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping:
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], aspace:AddrSpace, uncached=False, snooped=False, boot=False) -> VirtMapping:
if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
@@ -200,11 +202,11 @@ class MemoryManager:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize, paddr=paddr):
for pte_off in range(pte_cnt):
assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, aspace=aspace, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
self.on_range_mapped()
return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
return VirtMapping(vaddr, size, paddrs, aspace=aspace, uncached=uncached, snooped=snooped)
def unmap_range(self, vaddr:int, size:int):
if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
@@ -243,7 +245,7 @@ class MemoryManager:
continue
rem_size -= self.palloc_ranges[nxt_range][0]
return self.map_range(va, size, paddrs, uncached=uncached)
return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached)
def vfree(self, vm:VirtMapping):
assert self.va_allocator is not None, "must be set it"

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import ctypes, time, functools, re, gzip, struct
from tinygrad.helpers import getenv, DEBUG, fetch, getbits
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager, AddrSpace
from tinygrad.runtime.support.nv.ip import NV_FLCN, NV_FLCN_COT, NV_GSP
from tinygrad.runtime.support.system import System, PCIDevice, PCIDevImplBase
@@ -33,9 +33,9 @@ class NVPageTableEntry:
def _is_dual_pde(self) -> bool: return self.lv == self.nvdev.mm.level_cnt - 2
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, aspace=AddrSpace.PHYS, snooped=False, frag=0, valid=True):
if not table:
x = self.nvdev.pte_t.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if system else 0, kind=6,
x = self.nvdev.pte_t.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if aspace is AddrSpace.SYS else 0, kind=6,
**({'pcf': int(uncached)} if self.nvdev.mmu_ver == 3 else {'vol': uncached}))
else:
pde = self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t

View File

@@ -3,7 +3,7 @@ from typing import cast, ClassVar
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv
from tinygrad.runtime.autogen import libc, vfio, pci
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
@@ -262,7 +262,7 @@ class LNXPCIIfaceBase:
if should_use_sysmem:
vaddr = self.dev_impl.mm.alloc_vaddr(size:=round_up(size, mmap.PAGESIZE), align=mmap.PAGESIZE)
memview, paddrs = System.alloc_sysmem(size, vaddr=vaddr, contiguous=contiguous)
mapping = self.dev_impl.mm.map_range(vaddr, size, [(paddr, 0x1000) for paddr in paddrs], system=True, snooped=True, uncached=True)
mapping = self.dev_impl.mm.map_range(vaddr, size, [(paddr, 0x1000) for paddr in paddrs], aspace=AddrSpace.SYS, snooped=True, uncached=True)
return HCQBuffer(vaddr, size, meta=PCIAllocationMeta(mapping, has_cpu_mapping=True, hMemory=paddrs[0]), view=memview, owner=self.dev)
mapping = self.dev_impl.mm.valloc(size:=round_up(size, 0x1000), uncached=uncached, contiguous=cpu_access)
@@ -271,19 +271,23 @@ class LNXPCIIfaceBase:
def free(self, b:HCQBuffer):
for dev in b.mapped_devs[1:]: dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
if not b.meta.mapping.system: self.dev_impl.mm.vfree(b.meta.mapping)
if b.meta.mapping.aspace is AddrSpace.PHYS: self.dev_impl.mm.vfree(b.meta.mapping)
if b.owner == self.dev and b.meta.has_cpu_mapping and not OSX: FileIOInterface.munmap(b.va_addr, b.size)
def map(self, b:HCQBuffer):
if b.owner is not None and b.owner._is_cpu():
System.lock_memory(cast(int, b.va_addr), b.size)
paddrs, snooped, uncached = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], True, True
paddrs, aspace = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], AddrSpace.SYS
snooped, uncached = True, True
elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, LNXPCIIfaceBase):
paddrs = [(paddr if b.meta.mapping.system else (paddr + ifa.p2p_base_addr), size) for paddr,size in b.meta.mapping.paddrs]
snooped, uncached = b.meta.mapping.snooped, b.meta.mapping.uncached
snooped, uncached = True, b.meta.mapping.uncached
if b.meta.mapping.aspace is AddrSpace.SYS: paddrs, aspace = b.meta.mapping.paddrs, AddrSpace.SYS
elif hasattr(ifa.dev_impl, 'paddr2xgmi') and ifa.dev_impl.gmc.xgmi_seg_sz > 0:
paddrs, aspace = [(ifa.dev_impl.paddr2xgmi(p), sz) for p, sz in b.meta.mapping.paddrs], AddrSpace.PEER
else: paddrs, aspace = [(p + ifa.p2p_base_addr, sz) for p, sz in b.meta.mapping.paddrs], AddrSpace.SYS
else: raise RuntimeError(f"map failed: {b.owner} -> {self.dev}")
self.dev_impl.mm.map_range(cast(int, b.va_addr), round_up(b.size, 0x1000), paddrs, system=True, snooped=snooped, uncached=uncached)
self.dev_impl.mm.map_range(cast(int, b.va_addr), round_up(b.size, 0x1000), paddrs, aspace=aspace, snooped=snooped, uncached=uncached)
class APLPCIIfaceBase(LNXPCIIfaceBase):
def __init__(self, dev, dev_id, vendor, devices, bars, vram_bar, va_start, va_size):