am: recover from any boot interrupt (#8703)

* am: recover from any load interrupt

* add fuzzer

* nu
This commit is contained in:
nimlgen
2025-01-21 22:22:23 +03:00
committed by GitHub
parent 1e283c33d3
commit c5e46c5eee
3 changed files with 65 additions and 17 deletions

View File

@@ -0,0 +1,39 @@
import subprocess
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
def run_test(i, full_run=False):
  """Spawn one `TestTiny.test_plus` subprocess and either interrupt it or let it finish.

  i: iteration number, used only for the progress line.
  full_run: when False, the child is killed after a random delay of 0 to 1.2 seconds and its
    output is discarded. When True, the child runs to completion, its stderr is printed, and
    an AssertionError is raised unless stderr contains both "Ran 1 test in" and "OK".
  """
  print(f"\rRunning iteration {i}...", end=" ", flush=True)

  cmd = ['python3', 'test/test_tiny.py', 'TestTiny.test_plus']
  proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

  if not full_run:
    # Interrupt the child at an arbitrary point of its run, then reap it so no zombie is left behind.
    time.sleep(random.uniform(0, 1200) / 1000)
    proc.kill()
    proc.communicate()
    return

  # Uninterrupted run: unittest reports on stderr, so that is the stream that gets checked.
  _, err_bytes = proc.communicate()
  stderr_text = err_bytes.decode()
  print(stderr_text)
  assert "Ran 1 test in" in stderr_text and "OK" in stderr_text
# Fuzz driver: most iterations submit an interrupted run to the thread pool; every 100th
# iteration (including i == 0) first waits for all tracked interrupted runs and then performs
# one full run, which must pass, in the main thread.
max_workers = 4
with ThreadPoolExecutor(max_workers=max_workers) as executor:
  futures = []
  for i in range(1000000):
    if i % 100 != 0:
      futures.append(executor.submit(run_test, i, False))
      # Keep the tracking list short. NOTE: futures dropped here never have result() called,
      # so an exception raised inside them is not reported.
      if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()]
      continue

    # Checkpoint: wait for every tracked interrupted run so none overlaps the full run below.
    for future in as_completed(futures):
      try: future.result()
      except Exception as e:
        print(f"\nError in iteration: {e}")
    futures = []

    run_test(i, True)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
from tinygrad.runtime.support.allocator import TLSFAllocator
@@ -279,13 +279,10 @@ class AMDev:
self.partial_boot = False
if not self.partial_boot:
try: # do not interrupt the boot process
signal.signal(signal.SIGINT, signal.SIG_IGN)
if self.psp.is_sos_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
finally: signal.signal(signal.SIGINT, signal.default_int_handler)
if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
# Booting done
self.is_booting = False
@@ -332,8 +329,8 @@ class AMDev:
self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
  """Poll `reg` until `(reg.read() & mask) == value` and return the unmasked value that matched.

  reg: register object; only its `read()` method and `reg_off` attribute are used here.
  value: expected value of the masked read.
  mask: bits of the read that take part in the comparison.
  timeout: maximum number of reads. Each miss is followed by a 1 ms sleep, so this is roughly
    a budget in milliseconds (the reads themselves add to it).

  Raises RuntimeError when no read matched; the message reports the last value read in hex.
  """
  rval = None  # stays None only when timeout <= 0, i.e. the register was never read
  for _ in range(timeout):
    if ((rval:=reg.read()) & mask) == value: return rval
    time.sleep(0.001)
  # Format as hex to agree with the 0x prefix; guard the never-read case so the error raised
  # is still the RuntimeError callers expect rather than an UnboundLocalError.
  last_val = "<not read>" if rval is None else f"0x{rval:X}"
  raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val={last_val}')

View File

@@ -1,4 +1,4 @@
import ctypes, time
import ctypes, time, contextlib
from typing import Literal
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
@@ -106,22 +106,26 @@ class AM_SMU(AM_IP):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True)
def is_smu_alive(self):
  """Probe the SMU: send GetSmuVersion with a short timeout, then check the response register.

  A RuntimeError raised while sending (for instance a poll that times out) is swallowed on
  purpose. The result depends only on whether mmMP1_SMN_C2PMSG_90 reads back non-zero.
  """
  try:
    self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
  except RuntimeError:
    pass  # best-effort probe; the register read below decides the answer
  resp = self.adev.mmMP1_SMN_C2PMSG_90.read()
  return resp != 0
def mode1_reset(self):
  """Send the Mode1Reset message to the SMU, then pause for 500 ms before returning."""
  if DEBUG >= 2:
    print(f"am {self.adev.devfmt}: mode1 reset")
  reset_msg = smu_v13_0_0.PPSMC_MSG_Mode1Reset
  self._smu_cmn_send_smc_msg_with_param(reset_msg, 0, poll=True)
  time.sleep(0.5)  # 500ms
def _smu_cmn_poll_stat(self, timeout=10000):
  """Wait until mmMP1_SMN_C2PMSG_90 reads exactly 1.

  timeout: forwarded to `adev.wait_reg`, which raises RuntimeError when it runs out.
  """
  resp_reg = self.adev.mmMP1_SMN_C2PMSG_90
  self.adev.wait_reg(resp_reg, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
  """Write one message to the SMU message registers without waiting for a reply.

  Clears C2PMSG_90 (the response register), then writes `param` to C2PMSG_82 and `msg` to
  C2PMSG_66, in that order.
  """
  adev = self.adev
  adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
  adev.mmMP1_SMN_C2PMSG_82.write(param)
  # presumably the write to C2PMSG_66 is what hands the message to the SMU, so it stays last -- TODO confirm
  adev.mmMP1_SMN_C2PMSG_66.write(msg)
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
  """Send `msg` with `param` to the SMU and wait for the response register to read 1.

  poll: when true, also wait for the response register to read 1 before sending.
  read_back_arg: when true, return the value read back from C2PMSG_82; otherwise return None.
  timeout: poll budget passed to every `_smu_cmn_poll_stat` call made here, so that a caller
    asking for a short timeout is never held up by a poll using the default.

  Raises RuntimeError (from the underlying register wait) when a poll times out.
  """
  if poll: self._smu_cmn_poll_stat(timeout=timeout)

  self._smu_cmn_send_msg(msg, param)
  self._smu_cmn_poll_stat(timeout=timeout)
  # NOTE(review): other accesses to these registers go through .read()/.write(); confirm that
  # rreg() accepts the register object passed here.
  return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None
class AM_GFX(AM_IP):
@@ -319,8 +323,9 @@ class AM_PSP(AM_IP):
(am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
(am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
if not self.is_sos_alive():
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
self._ring_create()
self._tmr_init()
@@ -357,6 +362,13 @@ class AM_PSP(AM_IP):
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
def _ring_create(self):
# If the ring is already created, destroy it
if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
# There might be handshake issue with hardware which needs delay
time.sleep(0.02)
# Wait until the sOS is ready
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)