mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
am: mi3xx recovery (#15051)
This commit is contained in:
@@ -193,7 +193,7 @@ class AMDev(PCIDevImplBase):
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
|
||||
|
||||
def init_sw(self, smi_dev=False):
|
||||
self.smi_dev, self.is_err_state, self.has_aql_queue = smi_dev, False, False
|
||||
self.smi_dev, self.is_err_state = smi_dev, False
|
||||
|
||||
# Memory manager & firmware
|
||||
self.mm = AMMemoryManager(self, self.vram_size - self.reserved_vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, va_shifts=[12, 21, 30, 39],
|
||||
@@ -226,7 +226,7 @@ class AMDev(PCIDevImplBase):
|
||||
self.reg("regSCRATCH_REG6").write(self.is_err_state) # set finalized state.
|
||||
|
||||
def recover(self) -> bool:
|
||||
if (self.has_aql_queue and self.is_hive()) or not self.is_err_state: return False # TODO: support aql queue recovery on hive
|
||||
if not self.is_err_state: return False
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: Start recovery")
|
||||
self.ih.interrupt_handler()
|
||||
self.gfx.reset_mec()
|
||||
|
||||
@@ -243,7 +243,7 @@ class AM_GFX(AM_IP):
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
|
||||
|
||||
self.adev.gmc.init_hub("GC", inst_cnt=self.xccs)
|
||||
if self.adev.partial_boot: return
|
||||
if self.adev.partial_boot: return self.reset_mec()
|
||||
|
||||
self._config_mec()
|
||||
|
||||
@@ -291,18 +291,22 @@ class AM_GFX(AM_IP):
|
||||
|
||||
def reset_mec(self):
  """Dequeue the HQDs in reset mode, then configure and enable the MEC again.

  When more than one XCC is present, a GRBM soft reset (CP + GFX) is pulsed on
  every XCC instance between the dequeue and the reconfiguration.
  """
  self._dequeue_hqds(reset=True)

  # Multi-XCC systems need a soft reset so that the AQL sync counter is reset.
  if self.xccs > 1:
    instances = range(self.xccs)
    # Assert the reset bits on every instance first ...
    for inst_id in instances:
      self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_gfx=1, inst=inst_id)
    # ... hold them for 50 ms ...
    time.sleep(0.05)
    # ... then clear the whole register on every instance.
    for inst_id in instances:
      self.adev.regGRBM_SOFT_RESET.write(0x0, inst=inst_id)

  self._config_mec()
  self._enable_mec()
|
||||
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, idx:int, aql:bool) -> tuple[int, int]:
|
||||
self.adev.has_aql_queue |= aql
|
||||
pipe, queue, doorbell = idx // 4, idx % 4, am.AMDGPU_NAVI10_DOORBELL_MEC_RING0
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0)
|
||||
restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1)
|
||||
restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0
|
||||
if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.")
|
||||
|
||||
for xcc in range(self.xccs if aql else 1):
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc)
|
||||
|
||||
struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}{'_compute' if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else ''}_mqd")
|
||||
mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(self.mqd_mc[queue] + 0x1000*xcc),
|
||||
cp_mqd_base_addr_hi=hi32(self.mqd_mc[queue] + 0x1000*xcc), cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
||||
@@ -320,26 +324,16 @@ class AM_GFX(AM_IP):
|
||||
**({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc, 'cp_mqd_stride_size':0x1000} if aql and self.xccs > 1 else {}))
|
||||
for se in range(8 if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else 4): setattr(mqd_struct, f'compute_static_thread_mgmt_se{se}', 0xffffffff)
|
||||
|
||||
# Copy mqd into memory
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc)
|
||||
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
|
||||
|
||||
if restore_queue:
|
||||
for r in [self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR_HI,
|
||||
self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR_HI, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR_HI]:
|
||||
val = memoryview(bytes(mqd_struct)).cast('I')[0x80 + (off:=r.addr[xcc] - self.adev.regCP_MQD_BASE_ADDR.addr[xcc])]
|
||||
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct), fmt='I')[0x80 + off] = val
|
||||
r.write(val, inst=xcc)
|
||||
else:
|
||||
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
|
||||
|
||||
self.adev.gmc.flush_hdp()
|
||||
self._grbm_select(inst=xcc)
|
||||
return restore_ptr // 16, doorbell
|
||||
return 0, doorbell
|
||||
|
||||
def set_clockgating_state(self):
|
||||
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
|
||||
@@ -391,14 +385,12 @@ class AM_GFX(AM_IP):
|
||||
_config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc)
|
||||
|
||||
def _dequeue_hqds(self, reset=False):
|
||||
# NOTE: For aqls with xccs (queue=1), will continue from the saved state.
|
||||
for q in range(2 if self.xccs == 1 else 1):
|
||||
for q in range(2):
|
||||
for xcc in range(self.xccs):
|
||||
self._grbm_select(me=1, pipe=0, queue=q, inst=xcc)
|
||||
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1:
|
||||
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
||||
if reset: self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1, inst=xcc)
|
||||
else: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
|
||||
if not reset: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
|
||||
self._grbm_select()
|
||||
|
||||
class AM_IH(AM_IP):
|
||||
|
||||
Reference in New Issue
Block a user