mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
am: mi350 support (#13733)
This commit is contained in:
@@ -2,6 +2,10 @@ import re, ctypes, sys, importlib
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
from tinygrad.runtime.support.am.amdev import AMDev, AMRegister
|
||||
|
||||
class GFXFake:
|
||||
def __init__(self): self.xccs = 8
|
||||
|
||||
class AMDFake(AMDev):
|
||||
def __init__(self, pci_dev, dma_regions=None):
|
||||
self.pci_dev, self.devfmt, self.dma_regions = pci_dev, pci_dev.pcibus, dma_regions
|
||||
@@ -9,6 +13,8 @@ class AMDFake(AMDev):
|
||||
self._run_discovery()
|
||||
self._build_regs()
|
||||
|
||||
self.gfx = GFXFake()
|
||||
|
||||
amdev = importlib.import_module("tinygrad.runtime.support.am.amdev")
|
||||
amdev.AMDev = AMDFake
|
||||
from tinygrad.runtime.ops_amd import PCIIface
|
||||
|
||||
@@ -796,7 +796,7 @@ class PCIIface(PCIIfaceBase):
|
||||
gpus:ClassVar[list[str]] = []
|
||||
|
||||
def __init__(self, dev, dev_id):
|
||||
super().__init__(dev, dev_id, vendor=0x1002, devices=[(0xffff, [0x74a1, 0x744c, 0x7480, 0x7550, 0x7590])], bars=[0, 2, 5], vram_bar=0,
|
||||
super().__init__(dev, dev_id, vendor=0x1002, devices=[(0xffff, [0x74a1, 0x744c, 0x7480, 0x7550, 0x7590, 0x75a0])], bars=[0, 2, 5], vram_bar=0,
|
||||
va_start=AMMemoryManager.va_allocator.base, va_size=AMMemoryManager.va_allocator.size)
|
||||
self._setup_adev(self.pci_dev)
|
||||
self.pci_dev.write_config(pci.PCI_COMMAND, self.pci_dev.read_config(pci.PCI_COMMAND, 2) | pci.PCI_COMMAND_MASTER, 2)
|
||||
|
||||
@@ -42,14 +42,15 @@ class AMFirmware:
|
||||
self.descs: list[tuple[list[int], memoryview]] = []
|
||||
|
||||
# SMU firmware
|
||||
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", versioned_header="struct_smc_firmware_header")
|
||||
if self.adev.ip_ver[am.GC_HWIP] >= (11,0,0):
|
||||
self.smu_psp_desc = self.desc(blob, hdr.v1_0.header.ucode_array_offset_bytes, hdr.v1_0.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
|
||||
else:
|
||||
p2stables = (am.struct_smc_soft_pptable_entry * hdr.pptable_count).from_buffer(blob[hdr.pptable_entry_offset:])
|
||||
for p2stable in p2stables:
|
||||
if p2stable.id == (__P2S_TABLE_ID_X:=0x50325358):
|
||||
self.descs += [self.desc(blob, p2stable.ppt_offset_bytes, p2stable.ppt_size_bytes, am.GFX_FW_TYPE_P2S_TABLE)]
|
||||
if adev.ip_ver[am.MP1_HWIP] != (13,0,12):
|
||||
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", versioned_header="struct_smc_firmware_header")
|
||||
if self.adev.ip_ver[am.GC_HWIP] >= (11,0,0):
|
||||
self.smu_psp_desc = self.desc(blob, hdr.v1_0.header.ucode_array_offset_bytes, hdr.v1_0.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
|
||||
else:
|
||||
p2stables = (am.struct_smc_soft_pptable_entry * hdr.pptable_count).from_buffer(blob[hdr.pptable_entry_offset:])
|
||||
for p2stable in p2stables:
|
||||
if p2stable.id == (__P2S_TABLE_ID_X:=0x50325358):
|
||||
self.descs += [self.desc(blob, p2stable.ppt_offset_bytes, p2stable.ppt_size_bytes, am.GFX_FW_TYPE_P2S_TABLE)]
|
||||
|
||||
# SDMA firmware
|
||||
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", versioned_header="struct_sdma_firmware_header")
|
||||
@@ -300,7 +301,7 @@ class AMDev(PCIDevImplBase):
|
||||
def _build_regs(self):
|
||||
mods = [("mp", am.MP0_HWIP), ("hdp", am.HDP_HWIP), ("gc", am.GC_HWIP), ("mmhub", am.MMHUB_HWIP), ("osssys", am.OSSSYS_HWIP),
|
||||
("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
|
||||
if self.ip_ver[am.SDMA0_HWIP] == (4,4,2): mods += [("sdma", am.SDMA0_HWIP)]
|
||||
if self.ip_ver[am.SDMA0_HWIP] in {(4,4,2), (4,4,4)}: mods += [("sdma", am.SDMA0_HWIP)]
|
||||
|
||||
for prefix, hwip in mods:
|
||||
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))
|
||||
|
||||
@@ -15,7 +15,7 @@ class AM_SOC(AM_IP):
|
||||
def init_sw(self): self.module = import_soc(self.adev.ip_ver[am.GC_HWIP])
|
||||
|
||||
def init_hw(self):
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0):
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}:
|
||||
self.adev.regXCC_DOORBELL_FENCE.write(0x0)
|
||||
self.adev.regBIFC_GFX_INT_MONITOR_MASK.write(0x7ff)
|
||||
self.adev.regBIFC_DOORBELL_ACCESS_EN_PF.write(0xfffff)
|
||||
@@ -29,7 +29,7 @@ class AM_SOC(AM_IP):
|
||||
val = reg.encode(**{f"s2a_doorbell_port{port}_enable":1, f"s2a_doorbell_port{port}_awid":awid, f"s2a_doorbell_port{port}_range_size":size,
|
||||
f"s2a_doorbell_port{port}_awaddr_31_28_value":awaddr_31_28_value, f"s2a_doorbell_port{port}_range_offset":offset})
|
||||
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0): self.adev.indirect_wreg_pcie(reg.addr[0], val)
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}: self.adev.indirect_wreg_pcie(reg.addr[0], val)
|
||||
else: reg.write(val)
|
||||
|
||||
class AM_GMC(AM_IP):
|
||||
@@ -178,7 +178,7 @@ class AM_SMU(AM_IP):
|
||||
def mode1_reset(self):
|
||||
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
|
||||
if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0): self._send_msg(__DEBUGSMC_MSG_Mode1Reset:=2, 0, debug=True)
|
||||
elif self.adev.ip_ver[am.MP0_HWIP] == (13,0,6): self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1)
|
||||
elif self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: self._send_msg(self.smu_mod.PPSMC_MSG_GfxDriverReset, 1)
|
||||
else: self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0)
|
||||
time.sleep(0.5) # 500ms
|
||||
|
||||
@@ -188,7 +188,7 @@ class AM_SMU(AM_IP):
|
||||
def read_metrics(self): return self.read_table(self.smu_mod.SmuMetricsExternal_t, self.smu_mod.TABLE_SMU_METRICS)
|
||||
|
||||
def set_clocks(self, level):
|
||||
if self.adev.ip_ver[am.MP0_HWIP] == (13,0,6): return # TODO
|
||||
if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: return # TODO
|
||||
|
||||
if not hasattr(self, 'clcks'):
|
||||
self.clcks = {}
|
||||
@@ -230,13 +230,13 @@ class AM_GFX(AM_IP):
|
||||
|
||||
for xcc in range(self.xccs): self.adev.regRLC_SPM_MC_CNTL.write(0xf, inst=xcc)
|
||||
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] != (7,9,0):
|
||||
if self.adev.ip_ver[am.NBIO_HWIP][:2] != (7,9):
|
||||
self.adev.soc.doorbell_enable(port=0, awid=0x3, awaddr_31_28_value=0x3)
|
||||
self.adev.soc.doorbell_enable(port=3, awid=0x6, awaddr_31_28_value=0x3)
|
||||
|
||||
for xcc in range(self.xccs):
|
||||
if self.adev.ip_ver[am.GC_HWIP] == (9,4,3):
|
||||
self.adev.regGB_ADDR_CONFIG.write(0x2a114042, inst=xcc) # Golden value for mi300
|
||||
if self.adev.ip_ver[am.GC_HWIP] in {(9,4,3), (9,5,0)}:
|
||||
self.adev.regGB_ADDR_CONFIG.write(0x2a114042, inst=xcc) # Golden value for mi300/mi350
|
||||
self.adev.regTCP_UTCL1_CNTL2.update(spare=1, inst=xcc)
|
||||
|
||||
self.adev.regGRBM_CNTL.update(read_timeout=0xff, inst=xcc)
|
||||
@@ -380,7 +380,7 @@ class AM_IH(AM_IP):
|
||||
for _, rwptr_vm, suf, ring_id in self.rings:
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
|
||||
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] != (7,9,0):
|
||||
if self.adev.ip_ver[am.NBIO_HWIP][:2] != (7,9):
|
||||
self.adev.soc.doorbell_enable(port=1, awid=0x0, awaddr_31_28_value=0x0, offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, size=2)
|
||||
|
||||
def interrupt_handler(self):
|
||||
@@ -410,14 +410,14 @@ class AM_SDMA(AM_IP):
|
||||
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1,
|
||||
**({'utc_l1_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP] <= (5,2,0) else {}))
|
||||
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] == (7,9,0):
|
||||
if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}:
|
||||
self.adev.regDOORBELL0_CTRL_ENTRY_1.write(bif_doorbell1_range_offset_entry=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2,
|
||||
bif_doorbell1_range_size_entry=4)
|
||||
self.adev.soc.doorbell_enable(port=2, awid=0xe, awaddr_31_28_value=0x1, offset=0xe, size=4)
|
||||
else: self.adev.soc.doorbell_enable(port=2, awid=0xe, awaddr_31_28_value=0x3, offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, size=4)
|
||||
|
||||
def fini_hw(self):
|
||||
reg, inst = ("regSDMA_GFX", 0) if self.adev.ip_ver[am.SDMA0_HWIP] == (4,4,2) else ("regSDMA0_QUEUE0", 0)
|
||||
reg, inst = ("regSDMA_GFX", 0) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else ("regSDMA0_QUEUE0", 0)
|
||||
|
||||
self.adev.reg(f"{reg}_RB_CNTL").update(rb_enable=0, inst=inst)
|
||||
self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=0, inst=inst)
|
||||
@@ -428,7 +428,7 @@ class AM_SDMA(AM_IP):
|
||||
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int) -> int:
|
||||
# Setup the ring
|
||||
reg, inst = ("regSDMA_GFX", pipe*4+queue) if self.adev.ip_ver[am.SDMA0_HWIP] == (4,4,2) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
|
||||
reg, inst = ("regSDMA_GFX", pipe*4+queue) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
|
||||
|
||||
self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x1, inst=inst)
|
||||
self.adev.wreg_pair(f"{reg}_RB_RPTR", "", "_HI", 0, inst=inst)
|
||||
@@ -439,7 +439,7 @@ class AM_SDMA(AM_IP):
|
||||
self.adev.reg(f"{reg}_DOORBELL_OFFSET").update(offset=doorbell * 2, inst=inst)
|
||||
self.adev.reg(f"{reg}_DOORBELL").update(enable=1, inst=inst)
|
||||
self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x0, inst=inst)
|
||||
self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP] != (4,4,2) else {}),
|
||||
self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP][:2]!=(4,4) else {}),
|
||||
rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4, rb_enable=1, rb_priv=1, rb_size=(ring_size//4).bit_length()-1, inst=inst)
|
||||
self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=1, inst=inst)
|
||||
return self.adev.reg(f"{reg}_RB_WPTR").read() | (self.adev.reg(f"{reg}_RB_WPTR_HI").read() << 32)
|
||||
|
||||
@@ -38,8 +38,9 @@ def fixup_ip_version(ip:str, version:tuple[int, ...]) -> list[tuple[int, ...]]:
|
||||
return version
|
||||
|
||||
if ip in ['nbio', 'nbif']: version = _apply_ovrd({(3,3): (2,3,0), (7,3): (7,2,0)})
|
||||
elif ip in ['mp', 'smu']: version = _apply_ovrd({(14,0,3): (14,0,2)})
|
||||
elif ip in ['mp', 'smu']: version = _apply_ovrd({(14,0,3): (14,0,2), (13,0,12): (13,0,6)})
|
||||
elif ip in ['gc']: version = _apply_ovrd({(9,5,0): (9,4,3)})
|
||||
elif ip in ['sdma']: version = _apply_ovrd({(4,4,4): (4,4,2)})
|
||||
|
||||
return [version, version[:2], version[:2]+(0,), version[:1]+(0, 0)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user