am: support xgmi systems (#13659)

* am: support xgmi systems

* fake_am
This commit is contained in:
nimlgen
2025-12-12 18:55:45 +03:00
committed by GitHub
parent b4796e2d32
commit e36385e570
3 changed files with 24 additions and 8 deletions

View File

@@ -32,6 +32,8 @@ class FakeAM:
self.ip_ver = {am.GC_HWIP: (11, 0, 0)}
def paddr2cpu(self, paddr:int) -> int:
  """Return ``paddr`` offset by ``mv_address(self.vram)``."""
  vram_addr = mv_address(self.vram)
  return paddr + vram_addr
def paddr2mc(self, paddr:int) -> int:
  """Return ``paddr`` unchanged; this implementation applies no offset."""
  return paddr
def paddr2xgmi(self, paddr:int) -> int:
  """Return ``paddr`` unchanged; this implementation applies no offset."""
  return paddr
def xgmi2paddr(self, xgmi_paddr:int) -> int:
  """Return ``xgmi_paddr`` unchanged; this implementation applies no offset."""
  return xgmi_paddr
# * PTE format:
# * 63:59 reserved

View File

@@ -98,12 +98,15 @@ class AMPageTableEntry:
def __init__(self, adev, paddr, lv):
  """Store the device, address and level, and open a view over the table's storage.

  ``entries`` is set to ``adev.vram.view(paddr, 0x1000, fmt='Q')``.
  """
  # Build the view before touching self so a failing view() leaves no attribute assigned.
  table_view = adev.vram.view(paddr, 0x1000, fmt='Q')
  self.adev = adev
  self.paddr = paddr
  self.lv = lv
  self.entries = table_view
# Write the entry at index `entry_id`: the flags returned by gmc.get_pte_flags() OR-ed with the address bits of `paddr`.
# When `system` is false, `paddr` is first translated with adev.paddr2xgmi(); when true it is used as passed.
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
if not system: paddr = self.adev.paddr2xgmi(paddr)
# The (possibly translated) address must lie entirely within gmc.address_space_mask.
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
# Only bits 47:12 of the address are kept in the stored entry.
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
def entry(self, entry_id:int) -> int:
  """Return the raw value stored at index ``entry_id`` of ``self.entries``."""
  raw_entry = self.entries[entry_id]
  return raw_entry
def valid(self, entry_id:int) -> bool:
  """Report whether the ``am.AMDGPU_PTE_VALID`` bit is set in the entry at ``entry_id``."""
  valid_bits = self.entries[entry_id] & am.AMDGPU_PTE_VALID
  return valid_bits != 0
def address(self, entry_id:int) -> int:
  """Return the address bits (47:12) of the entry at ``entry_id``."""
  addr_mask = 0x0000FFFFFFFFF000
  return self.entries[entry_id] & addr_mask
# Return bits 47:12 of the entry at `entry_id`, translated with adev.xgmi2paddr().
def address(self, entry_id:int) -> int:
# Entries carrying the AMDGPU_PTE_SYSTEM flag are rejected; only non-system entries are translated.
assert self.entries[entry_id] & am.AMDGPU_PTE_SYSTEM == 0, "should not be system address"
return self.adev.xgmi2paddr(self.entries[entry_id] & 0x0000FFFFFFFFF000)
def is_page(self, entry_id:int) -> bool:
  """Report whether the entry at ``entry_id`` is a page.

  True when this table's level equals ``am.AMDGPU_VM_PTB``; otherwise the answer is whatever
  ``gmc.is_pte_huge_page()`` returns for the entry.
  """
  at_ptb_level = self.lv == am.AMDGPU_VM_PTB
  if at_ptb_level: return at_ptb_level
  # The entry is only read when the level check did not already decide the answer.
  return self.adev.gmc.is_pte_huge_page(self.entries[entry_id])
def supports_huge_page(self, paddr:int):
  """Report whether this table's level is at least ``am.AMDGPU_VM_PDB2``; ``paddr`` is not consulted."""
  level = self.lv
  return level >= am.AMDGPU_VM_PDB2
@@ -190,6 +193,8 @@ class AMDev(PCIDevImplBase):
self.ih.interrupt_handler()
def paddr2mc(self, paddr:int) -> int:
  """Return ``paddr`` offset by ``self.gmc.mc_base``."""
  mc_base = self.gmc.mc_base
  return mc_base + paddr
def paddr2xgmi(self, paddr:int) -> int:
  """Return ``paddr`` offset by ``self.gmc.paddr_base``."""
  xgmi_base = self.gmc.paddr_base
  return xgmi_base + paddr
def xgmi2paddr(self, xgmi_paddr:int) -> int:
  """Return ``xgmi_paddr`` with ``self.gmc.paddr_base`` subtracted."""
  xgmi_base = self.gmc.paddr_base
  return xgmi_paddr - xgmi_base
def reg(self, reg:str) -> AMRegister:
  """Look up the attribute named ``reg`` in this object's instance dictionary.

  Raises ``KeyError`` when the instance dictionary has no such key.
  """
  instance_attrs = self.__dict__
  return instance_attrs[reg]

View File

@@ -28,8 +28,17 @@ class AM_GMC(AM_IP):
def init_sw(self):
self.vmhubs = len(self.adev.regs_offset[am.MMHUB_HWIP])
# XGMI (for supported systems)
xgmi_phys_id = self.adev.regMMMC_VM_XGMI_LFB_CNTL.read_bitfields()['pf_lfb_region'] if hasattr(self.adev, 'regMMMC_VM_XGMI_LFB_CNTL') else 0
xgmi_seg_sz = (self.adev.regMMMC_VM_XGMI_LFB_SIZE.read_bitfields()['pf_lfb_size'] << 24) if hasattr(self.adev, 'regMMMC_VM_XGMI_LFB_SIZE') else 0
self.paddr_base = xgmi_phys_id * xgmi_seg_sz
self.fb_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
self.fb_end = (self.adev.regMMMC_VM_FB_LOCATION_TOP.read() & 0xFFFFFF) << 24
# Memory controller aperture
self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
self.mc_base = self.fb_base + self.paddr_base
self.mc_end = self.mc_base + self.adev.mm.vram_size - 1
# VM aperture
@@ -39,8 +48,8 @@ class AM_GMC(AM_IP):
# GFX11/GFX12 has 44-bit address space
self.address_space_mask = (1 << 44) - 1
self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=False, boot=True)
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=False, boot=True)
self.memscratch_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True))
self.dummy_page_xgmi_paddr = self.adev.paddr2xgmi(self.adev.mm.palloc(0x1000, zero=False, boot=True))
self.hub_initted = {"MM": False, "GC": False}
self.pf_status_reg = lambda ip: f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS{'_LO32' if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else ''}"
@@ -71,7 +80,7 @@ class AM_GMC(AM_IP):
def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid, inst):
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12, inst=inst)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12, inst=inst)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1, inst=inst)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", self.adev.paddr2xgmi(page_table.paddr) | 1, inst=inst)
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1,
dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1,
range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1,
@@ -90,8 +99,8 @@ class AM_GMC(AM_IP):
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18, inst=inst)
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18, inst=inst)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12, inst=inst)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12, inst=inst)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_xgmi_paddr >> 12, inst=inst)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_xgmi_paddr >> 12, inst=inst)
self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1, inst=inst)
@@ -500,7 +509,7 @@ class AM_PSP(AM_IP):
# Build a GFX_CMD_ID_SETUP_TMR command for the buffer at self.tmr_paddr and return the result of _ring_submit().
def _tmr_load_cmd(self) -> am.struct_psp_gfx_cmd_resp:
cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_SETUP_TMR)
# buf_phy_addr: the paddr2mc() translation of the TMR address; the data64() result is unpacked into the _hi/_lo fields.
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
# system_phy_addr is assigned twice below: first from the raw tmr_paddr, then from its paddr2xgmi() translation,
# so the later assignment is the one that takes effect.
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.adev.paddr2xgmi(self.tmr_paddr))
cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1
# The buffer size comes from self.tmr_size.
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
return self._ring_submit(cmd)