am: mi350 support (#13733)

This commit is contained in:
nimlgen
2025-12-17 14:57:21 +03:00
committed by GitHub
parent 5151a341b3
commit 3eecb4f123
5 changed files with 31 additions and 23 deletions

View File

@@ -2,6 +2,10 @@ import re, ctypes, sys, importlib
from tinygrad.helpers import getenv
from tinygrad.runtime.support.am.amdev import AMDev, AMRegister
# Minimal placeholder object that AMDFake assigns to its `gfx` attribute.
# It carries a single attribute, `xccs`, hard-coded to 8.
# NOTE(review): presumably 8 mirrors the XCC count of the emulated part — confirm.
class GFXFake:
def __init__(self): self.xccs = 8
class AMDFake(AMDev):
def __init__(self, pci_dev, dma_regions=None):
self.pci_dev, self.devfmt, self.dma_regions = pci_dev, pci_dev.pcibus, dma_regions
@@ -9,6 +13,8 @@ class AMDFake(AMDev):
self._run_discovery()
self._build_regs()
self.gfx = GFXFake()
# Swap the AMDev class exposed by the amdev module for AMDFake, and only then import PCIIface.
# NOTE(review): the ordering looks deliberate (patch first, import ops_amd afterwards) — presumably so
# that later lookups of `amdev.AMDev` resolve to the fake; confirm before reordering these lines.
amdev = importlib.import_module("tinygrad.runtime.support.am.amdev")
amdev.AMDev = AMDFake
from tinygrad.runtime.ops_amd import PCIIface

View File

@@ -796,7 +796,7 @@ class PCIIface(PCIIfaceBase):
gpus:ClassVar[list[str]] = []
def __init__(self, dev, dev_id):
super().__init__(dev, dev_id, vendor=0x1002, devices=[(0xffff, [0x74a1, 0x744c, 0x7480, 0x7550, 0x7590])], bars=[0, 2, 5], vram_bar=0,
super().__init__(dev, dev_id, vendor=0x1002, devices=[(0xffff, [0x74a1, 0x744c, 0x7480, 0x7550, 0x7590, 0x75a0])], bars=[0, 2, 5], vram_bar=0,
va_start=AMMemoryManager.va_allocator.base, va_size=AMMemoryManager.va_allocator.size)
self._setup_adev(self.pci_dev)
self.pci_dev.write_config(pci.PCI_COMMAND, self.pci_dev.read_config(pci.PCI_COMMAND, 2) | pci.PCI_COMMAND_MASTER, 2)

View File

@@ -42,14 +42,15 @@ class AMFirmware:
self.descs: list[tuple[list[int], memoryview]] = []
# SMU firmware
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", versioned_header="struct_smc_firmware_header")
if self.adev.ip_ver[am.GC_HWIP] >= (11,0,0):
self.smu_psp_desc = self.desc(blob, hdr.v1_0.header.ucode_array_offset_bytes, hdr.v1_0.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
else:
p2stables = (am.struct_smc_soft_pptable_entry * hdr.pptable_count).from_buffer(blob[hdr.pptable_entry_offset:])
for p2stable in p2stables:
if p2stable.id == (__P2S_TABLE_ID_X:=0x50325358):
self.descs += [self.desc(blob, p2stable.ppt_offset_bytes, p2stable.ppt_size_bytes, am.GFX_FW_TYPE_P2S_TABLE)]
if adev.ip_ver[am.MP1_HWIP] != (13,0,12):
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", versioned_header="struct_smc_firmware_header")
if self.adev.ip_ver[am.GC_HWIP] >= (11,0,0):
self.smu_psp_desc = self.desc(blob, hdr.v1_0.header.ucode_array_offset_bytes, hdr.v1_0.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
else:
p2stables = (am.struct_smc_soft_pptable_entry * hdr.pptable_count).from_buffer(blob[hdr.pptable_entry_offset:])
for p2stable in p2stables:
if p2stable.id == (__P2S_TABLE_ID_X:=0x50325358):
self.descs += [self.desc(blob, p2stable.ppt_offset_bytes, p2stable.ppt_size_bytes, am.GFX_FW_TYPE_P2S_TABLE)]
# SDMA firmware
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", versioned_header="struct_sdma_firmware_header")
@@ -300,7 +301,7 @@ class AMDev(PCIDevImplBase):
def _build_regs(self):
mods = [("mp", am.MP0_HWIP), ("hdp", am.HDP_HWIP), ("gc", am.GC_HWIP), ("mmhub", am.MMHUB_HWIP), ("osssys", am.OSSSYS_HWIP),
("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
if self.ip_ver[am.SDMA0_HWIP] == (4,4,2): mods += [("sdma", am.SDMA0_HWIP)]
if self.ip_ver[am.SDMA0_HWIP] in {(4,4,2), (4,4,4)}: mods += [("sdma", am.SDMA0_HWIP)]
for prefix, hwip in mods:
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))

View File

@@ -15,7 +15,7 @@ class AM_SOC(AM_IP):
def init_sw(self):
  """Store in `self.module` whatever `import_soc` returns for this device's GC IP version."""
  gc_version = self.adev.ip_ver[am.GC_HWIP]
  self.module = import_soc(gc_version)
def init_hw(self):
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0):
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}:
self.adev.regXCC_DOORBELL_FENCE.write(0x0)
self.adev.regBIFC_GFX_INT_MONITOR_MASK.write(0x7ff)
self.adev.regBIFC_DOORBELL_ACCESS_EN_PF.write(0xfffff)
@@ -29,7 +29,7 @@ class AM_SOC(AM_IP):
val = reg.encode(**{f"s2a_doorbell_port{port}_enable":1, f"s2a_doorbell_port{port}_awid":awid, f"s2a_doorbell_port{port}_range_size":size,
f"s2a_doorbell_port{port}_awaddr_31_28_value":awaddr_31_28_value, f"s2a_doorbell_port{port}_range_offset":offset})
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0): self.adev.indirect_wreg_pcie(reg.addr[0], val)
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}: self.adev.indirect_wreg_pcie(reg.addr[0], val)
else: reg.write(val)
class AM_GMC(AM_IP):
@@ -178,7 +178,7 @@ class AM_SMU(AM_IP):
def mode1_reset(self):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0): self._send_msg(__DEBUGSMC_MSG_Mode1Reset:=2, 0, debug=True)
elif self.adev.ip_ver[am.MP0_HWIP] == (13,0,6): self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1)
elif self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1)
else: self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0)
time.sleep(0.5) # 500ms
@@ -188,7 +188,7 @@ class AM_SMU(AM_IP):
def read_metrics(self):
  """Return the result of `self.read_table` called with `SmuMetricsExternal_t` and `TABLE_SMU_METRICS` from `self.smu_mod`."""
  reader = self.read_table
  smu = self.smu_mod
  return reader(smu.SmuMetricsExternal_t, smu.TABLE_SMU_METRICS)
def set_clocks(self, level):
if self.adev.ip_ver[am.MP0_HWIP] == (13,0,6): return # TODO
if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: return # TODO
if not hasattr(self, 'clcks'):
self.clcks = {}
@@ -230,13 +230,13 @@ class AM_GFX(AM_IP):
for xcc in range(self.xccs): self.adev.regRLC_SPM_MC_CNTL.write(0xf, inst=xcc)
if self.adev.ip_ver[am.NBIO_HWIP] != (7,9,0):
if self.adev.ip_ver[am.NBIO_HWIP][:2] != (7,9):
self.adev.soc.doorbell_enable(port=0, awid=0x3, awaddr_31_28_value=0x3)
self.adev.soc.doorbell_enable(port=3, awid=0x6, awaddr_31_28_value=0x3)
for xcc in range(self.xccs):
if self.adev.ip_ver[am.GC_HWIP] == (9,4,3):
self.adev.regGB_ADDR_CONFIG.write(0x2a114042, inst=xcc) # Golden value for mi300
if self.adev.ip_ver[am.GC_HWIP] in {(9,4,3), (9,5,0)}:
self.adev.regGB_ADDR_CONFIG.write(0x2a114042, inst=xcc) # Golden value for mi300/mi350
self.adev.regTCP_UTCL1_CNTL2.update(spare=1, inst=xcc)
self.adev.regGRBM_CNTL.update(read_timeout=0xff, inst=xcc)
@@ -380,7 +380,7 @@ class AM_IH(AM_IP):
for _, rwptr_vm, suf, ring_id in self.rings:
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
if self.adev.ip_ver[am.NBIO_HWIP] != (7,9,0):
if self.adev.ip_ver[am.NBIO_HWIP][:2] != (7,9):
self.adev.soc.doorbell_enable(port=1, awid=0x0, awaddr_31_28_value=0x0, offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, size=2)
def interrupt_handler(self):
@@ -410,14 +410,14 @@ class AM_SDMA(AM_IP):
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1,
**({'utc_l1_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP] <= (5,2,0) else {}))
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0):
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}:
self.adev.regDOORBELL0_CTRL_ENTRY_1.write(bif_doorbell1_range_offset_entry=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2,
bif_doorbell1_range_size_entry=4)
self.adev.soc.doorbell_enable(port=2, awid=0xe, awaddr_31_28_value=0x1, offset=0xe, size=4)
else: self.adev.soc.doorbell_enable(port=2, awid=0xe, awaddr_31_28_value=0x3, offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, size=4)
def fini_hw(self):
reg, inst = ("regSDMA_GFX", 0) if self.adev.ip_ver[am.SDMA0_HWIP] == (4,4,2) else ("regSDMA0_QUEUE0", 0)
reg, inst = ("regSDMA_GFX", 0) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else ("regSDMA0_QUEUE0", 0)
self.adev.reg(f"{reg}_RB_CNTL").update(rb_enable=0, inst=inst)
self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=0, inst=inst)
@@ -428,7 +428,7 @@ class AM_SDMA(AM_IP):
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int) -> int:
# Setup the ring
reg, inst = ("regSDMA_GFX", pipe*4+queue) if self.adev.ip_ver[am.SDMA0_HWIP] == (4,4,2) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
reg, inst = ("regSDMA_GFX", pipe*4+queue) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x1, inst=inst)
self.adev.wreg_pair(f"{reg}_RB_RPTR", "", "_HI", 0, inst=inst)
@@ -439,7 +439,7 @@ class AM_SDMA(AM_IP):
self.adev.reg(f"{reg}_DOORBELL_OFFSET").update(offset=doorbell * 2, inst=inst)
self.adev.reg(f"{reg}_DOORBELL").update(enable=1, inst=inst)
self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x0, inst=inst)
self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP] != (4,4,2) else {}),
self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP][:2]!=(4,4) else {}),
rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4, rb_enable=1, rb_priv=1, rb_size=(ring_size//4).bit_length()-1, inst=inst)
self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=1, inst=inst)
return self.adev.reg(f"{reg}_RB_WPTR").read() | (self.adev.reg(f"{reg}_RB_WPTR_HI").read() << 32)

View File

@@ -38,8 +38,9 @@ def fixup_ip_version(ip:str, version:tuple[int, ...]) -> list[tuple[int, ...]]:
return version
if ip in ['nbio', 'nbif']: version = _apply_ovrd({(3,3): (2,3,0), (7,3): (7,2,0)})
elif ip in ['mp', 'smu']: version = _apply_ovrd({(14,0,3): (14,0,2)})
elif ip in ['mp', 'smu']: version = _apply_ovrd({(14,0,3): (14,0,2), (13,0,12): (13,0,6)})
elif ip in ['gc']: version = _apply_ovrd({(9,5,0): (9,4,3)})
elif ip in ['sdma']: version = _apply_ovrd({(4,4,4): (4,4,2)})
return [version, version[:2], version[:2]+(0,), version[:1]+(0, 0)]