mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
am: do not reload fw each run (#8466)
* am do not reload fw each run * works * comment this * clean + comment * warn message * linter * move out pci en master * useless * more correct * oops * oops
This commit is contained in:
@@ -467,6 +467,9 @@ class PCIIface:
|
||||
self.adev = AMDev(self.pcidev, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
|
||||
self.doorbell_cpu_addr = mv_address(dbell)
|
||||
|
||||
libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
|
||||
libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
|
||||
|
||||
# TODO: this is for 7900xtx, the only tested card.
|
||||
self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32,
|
||||
'array_count': 12, 'simd_arrays_per_engine': 2, 'lds_size_in_kb': 64}
|
||||
|
||||
@@ -127,8 +127,9 @@ class AMMemoryManager:
|
||||
|
||||
def __init__(self, adev, vram_size:int):
|
||||
self.adev, self.vram_size = adev, vram_size
|
||||
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
|
||||
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
|
||||
self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True), lv=am.AMDGPU_VM_PDB1)
|
||||
self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
|
||||
|
||||
def page_table_walker(self, page_table, vaddr, size, offset=0, free_pt=False, creat_pt=True):
|
||||
"""
|
||||
@@ -241,8 +242,9 @@ class AMMemoryManager:
|
||||
self.va_allocator.free(vm.va_addr)
|
||||
if vm.paddr is not None: self.pa_allocator.free(vm.paddr)
|
||||
|
||||
def palloc(self, size, align=0x1000, zero=True) -> AMPhysicalMemoryBlock:
|
||||
pm = AMPhysicalMemoryBlock(self.adev, self.pa_allocator.alloc(round_up(size, 0x1000), align), size)
|
||||
def palloc(self, size, align=0x1000, zero=True, boot=False) -> AMPhysicalMemoryBlock:
|
||||
assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
|
||||
pm = AMPhysicalMemoryBlock(self.adev, (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align), size)
|
||||
if zero: ctypes.memset(pm.cpu_addr(), 0, pm.size)
|
||||
return pm
|
||||
|
||||
@@ -256,6 +258,18 @@ class AMDev:
|
||||
self._run_discovery()
|
||||
self._build_regs()
|
||||
|
||||
# AM boot Process:
|
||||
# The GPU being passed can be in one of several states: 1. Not initialized. 2. Initialized by amdgpu. 3. Initialized by AM.
|
||||
# The 1st and 2nd states require a full GPU setup since their states are unknown. The 2nd state also requires a mode1 reset to
|
||||
# reinitialize all components.
|
||||
#
|
||||
# The 3rd state can be set up partially to optimize boot time. In this case, only the GFX and SDMA IPs need to be initialized.
|
||||
# To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
|
||||
# all blocks that are initialized only during the initial AM boot.
|
||||
# To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
|
||||
self.is_booting = True
|
||||
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1)
|
||||
|
||||
# Memory manager & firmware
|
||||
self.mm = AMMemoryManager(self, self.vram_size)
|
||||
self.fw = AMFirmware()
|
||||
@@ -269,11 +283,21 @@ class AMDev:
|
||||
self.gfx:AM_GFX = AM_GFX(self)
|
||||
self.sdma:AM_SDMA = AM_SDMA(self)
|
||||
|
||||
if self.psp.is_sos_alive(): self.smu.mode1_reset()
|
||||
if self.partial_boot and (self.reg("regCP_MEC_RS64_CNTL").read() & gc_11_0_0.CP_MEC_RS64_CNTL__MEC_HALT_MASK == 0):
|
||||
print("am: MEC is active. Someone might be using the GPU? Issue a full reset.")
|
||||
self.partial_boot = False
|
||||
|
||||
# Initialize all blocks
|
||||
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init()
|
||||
if not self.partial_boot:
|
||||
if self.psp.is_sos_alive(): self.smu.mode1_reset()
|
||||
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]: ip.init()
|
||||
|
||||
# Booting done
|
||||
self.is_booting = False
|
||||
|
||||
# Re-initialize main blocks
|
||||
for ip in [self.gfx, self.sdma]: ip.init()
|
||||
self.gfx.set_clockgating_state()
|
||||
self.reg("regSCRATCH_REG7").write(am_version)
|
||||
|
||||
def fini(self):
|
||||
for ip in [self.sdma, self.gfx]: ip.fini()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import ctypes, time
|
||||
from typing import Literal
|
||||
from tinygrad.runtime.autogen import libpciaccess
|
||||
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
|
||||
from tinygrad.helpers import to_mv, data64, lo32, hi32
|
||||
|
||||
@@ -26,8 +25,8 @@ class AM_GMC(AM_IP):
|
||||
self.vm_base = self.adev.mm.va_allocator.base
|
||||
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1
|
||||
|
||||
self.memscratch_pm = self.adev.mm.palloc(0x1000)
|
||||
self.dummy_page_pm = self.adev.mm.palloc(0x1000)
|
||||
self.memscratch_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
|
||||
self.dummy_page_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
|
||||
self.hub_initted = {"MM": False, "GC": False}
|
||||
|
||||
def init(self): self.init_hub("MM")
|
||||
@@ -230,9 +229,14 @@ class AM_GFX(AM_IP):
|
||||
_config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=4, me=1)
|
||||
|
||||
class AM_IH(AM_IP):
|
||||
def __init__(self, adev):
|
||||
super().__init__(adev)
|
||||
self.rings = [(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
|
||||
(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
|
||||
|
||||
def interrupt_handler(self):
|
||||
ring_vm, rwptr_vm, suf, _ = self.rings[0]
|
||||
wptr = to_mv(rwptr_vm.cpu_addr, 8).cast('Q')[0]
|
||||
wptr = to_mv(rwptr_vm.cpu_addr(), 8).cast('Q')[0]
|
||||
|
||||
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
|
||||
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
|
||||
@@ -241,16 +245,13 @@ class AM_IH(AM_IP):
|
||||
self.adev.regIH_RB_RPTR.write(wptr % ring_vm.size)
|
||||
|
||||
def init(self):
|
||||
self.rings = [(self.adev.mm.valloc(512 << 10, uncached=True, contigous=True), self.adev.mm.valloc(0x1000, uncached=True, contigous=True), "", 0),
|
||||
(self.adev.mm.valloc(512 << 10, uncached=True, contigous=True), self.adev.mm.valloc(0x1000, uncached=True, contigous=True), "_RING1", 1)]
|
||||
|
||||
for ring_vm, rwptr_vm, suf, ring_id in self.rings:
|
||||
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.va_addr >> 8)
|
||||
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.mc_addr() >> 8)
|
||||
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(ring_vm.size//4).bit_length(),
|
||||
mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1}))
|
||||
|
||||
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.va_addr)
|
||||
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.mc_addr())
|
||||
|
||||
self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
|
||||
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
|
||||
@@ -261,9 +262,6 @@ class AM_IH(AM_IP):
|
||||
self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1)
|
||||
self.adev.regIH_MSI_STORM_CTRL.update(delay=3)
|
||||
|
||||
libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
|
||||
libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
|
||||
|
||||
# toggle interrupts
|
||||
for _, rwptr_vm, suf, ring_id in self.rings:
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
|
||||
@@ -271,6 +269,7 @@ class AM_IH(AM_IP):
|
||||
class AM_SDMA(AM_IP):
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
|
||||
# Setup the ring
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x1)
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR", "", "_HI", 0)
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR", "", "_HI", 0)
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_BASE", "", "_HI", ring_addr >> 8)
|
||||
@@ -278,6 +277,7 @@ class AM_SDMA(AM_IP):
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR_POLL_ADDR", "_LO", "_HI", wptr_addr)
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_DOORBELL_OFFSET").update(offset=doorbell * 2)
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_DOORBELL").update(enable=1)
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x0)
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").write(rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4,
|
||||
f32_wptr_poll_enable=1, rb_size=(ring_size//4).bit_length()-1, rb_enable=1, rb_priv=1)
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)
|
||||
@@ -302,10 +302,10 @@ class AM_PSP(AM_IP):
|
||||
def __init__(self, adev):
|
||||
super().__init__(adev)
|
||||
|
||||
self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG)
|
||||
self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE)
|
||||
self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE)
|
||||
self.ring_pm = self.adev.mm.palloc(0x10000)
|
||||
self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
|
||||
self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
|
||||
self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
|
||||
self.ring_pm = self.adev.mm.palloc(0x10000, zero=not self.adev.partial_boot, boot=True)
|
||||
|
||||
def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
|
||||
def init(self):
|
||||
@@ -351,7 +351,7 @@ class AM_PSP(AM_IP):
|
||||
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
|
||||
resp = self._load_toc_cmd(len(fwm))
|
||||
|
||||
self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT)
|
||||
self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
|
||||
|
||||
def _ring_create(self):
|
||||
# Wait until the sOS is ready
|
||||
|
||||
Reference in New Issue
Block a user