mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
am: recover from any boot interrupt (#8703)
* am: recover from any load interrupt * add fuzzer * nu
This commit is contained in:
39
test/external/external_fuzz_am_interrupts.py
vendored
Normal file
39
test/external/external_fuzz_am_interrupts.py
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
import subprocess
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
def run_test(i, full_run=False):
|
||||
print(f"\rRunning iteration {i}...", end=" ", flush=True)
|
||||
|
||||
p = subprocess.Popen(['python3', 'test/test_tiny.py', 'TestTiny.test_plus'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
if not full_run:
|
||||
time.sleep(random.uniform(0, 1200) / 1000)
|
||||
p.kill()
|
||||
_, stderr = p.communicate()
|
||||
else:
|
||||
_, stderr = p.communicate()
|
||||
|
||||
if full_run:
|
||||
stderr_text = stderr.decode()
|
||||
print(stderr_text)
|
||||
assert "Ran 1 test in" in stderr_text and "OK" in stderr_text
|
||||
|
||||
max_workers = 4
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
for i in range(1000000):
|
||||
if i % 100 == 0:
|
||||
for future in as_completed(futures):
|
||||
try: future.result()
|
||||
except Exception as e:
|
||||
print(f"\nError in iteration: {e}")
|
||||
futures = []
|
||||
|
||||
run_test(i, True)
|
||||
else:
|
||||
future = executor.submit(run_test, i, False)
|
||||
futures.append(future)
|
||||
|
||||
if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()]
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
|
||||
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
|
||||
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
|
||||
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
|
||||
from tinygrad.runtime.support.allocator import TLSFAllocator
|
||||
@@ -279,13 +279,10 @@ class AMDev:
|
||||
self.partial_boot = False
|
||||
|
||||
if not self.partial_boot:
|
||||
try: # do not interrupt the boot process
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
if self.psp.is_sos_alive(): self.smu.mode1_reset()
|
||||
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
|
||||
ip.init()
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
|
||||
finally: signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
|
||||
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
|
||||
ip.init()
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
|
||||
|
||||
# Booting done
|
||||
self.is_booting = False
|
||||
@@ -332,8 +329,8 @@ class AMDev:
|
||||
self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
|
||||
self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
|
||||
|
||||
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff) -> int:
|
||||
for _ in range(10000):
|
||||
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
|
||||
for _ in range(timeout):
|
||||
if ((rval:=reg.read()) & mask) == value: return rval
|
||||
time.sleep(0.001)
|
||||
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import ctypes, time
|
||||
import ctypes, time, contextlib
|
||||
from typing import Literal
|
||||
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
|
||||
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
|
||||
@@ -106,22 +106,26 @@ class AM_SMU(AM_IP):
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True)
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True)
|
||||
|
||||
def is_smu_alive(self):
|
||||
with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
|
||||
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0
|
||||
|
||||
def mode1_reset(self):
|
||||
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
|
||||
time.sleep(0.5) # 500ms
|
||||
|
||||
def _smu_cmn_poll_stat(self): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1)
|
||||
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
|
||||
def _smu_cmn_send_msg(self, msg, param=0):
|
||||
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
|
||||
self.adev.mmMP1_SMN_C2PMSG_82.write(param)
|
||||
self.adev.mmMP1_SMN_C2PMSG_66.write(msg)
|
||||
|
||||
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False):
|
||||
if poll: self._smu_cmn_poll_stat()
|
||||
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
|
||||
if poll: self._smu_cmn_poll_stat(timeout=timeout)
|
||||
|
||||
self._smu_cmn_send_msg(msg, param)
|
||||
self._smu_cmn_poll_stat()
|
||||
self._smu_cmn_poll_stat(timeout=timeout)
|
||||
return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None
|
||||
|
||||
class AM_GFX(AM_IP):
|
||||
@@ -319,8 +323,9 @@ class AM_PSP(AM_IP):
|
||||
(am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
|
||||
(am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]
|
||||
|
||||
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
|
||||
while not self.is_sos_alive(): time.sleep(0.01)
|
||||
if not self.is_sos_alive():
|
||||
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
|
||||
while not self.is_sos_alive(): time.sleep(0.01)
|
||||
|
||||
self._ring_create()
|
||||
self._tmr_init()
|
||||
@@ -357,6 +362,13 @@ class AM_PSP(AM_IP):
|
||||
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
|
||||
|
||||
def _ring_create(self):
|
||||
# If the ring is already created, destroy it
|
||||
if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
|
||||
self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
|
||||
|
||||
# There might be handshake issue with hardware which needs delay
|
||||
time.sleep(0.02)
|
||||
|
||||
# Wait until the sOS is ready
|
||||
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user