From 4e12fc3fe6fae4c5614547eca70256d05d60e595 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 27 Feb 2026 22:10:47 +0300 Subject: [PATCH] am: mi3xx recovery (#15051) --- tinygrad/runtime/support/am/amdev.py | 4 +-- tinygrad/runtime/support/am/ip.py | 44 ++++++++++++---------------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 028062eeac..514b5f2281 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -193,7 +193,7 @@ class AMDev(PCIDevImplBase): if DEBUG >= 2: print(f"am {self.devfmt}: boot done") def init_sw(self, smi_dev=False): - self.smi_dev, self.is_err_state, self.has_aql_queue = smi_dev, False, False + self.smi_dev, self.is_err_state = smi_dev, False # Memory manager & firmware self.mm = AMMemoryManager(self, self.vram_size - self.reserved_vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, va_shifts=[12, 21, 30, 39], @@ -226,7 +226,7 @@ class AMDev(PCIDevImplBase): self.reg("regSCRATCH_REG6").write(self.is_err_state) # set finalized state. def recover(self) -> bool: - if (self.has_aql_queue and self.is_hive()) or not self.is_err_state: return False # TODO: support aql queue recovery on hive + if not self.is_err_state: return False if DEBUG >= 2: print(f"am {self.devfmt}: Start recovery") self.ih.interrupt_handler() self.gfx.reset_mec() diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 5341756016..393fb2c2a9 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -243,7 +243,7 @@ class AM_GFX(AM_IP): while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass self.adev.gmc.init_hub("GC", inst_cnt=self.xccs) - if self.adev.partial_boot: return + if self.adev.partial_boot: return self.reset_mec() self._config_mec() @@ -291,18 +291,22 @@ class AM_GFX(AM_IP): def reset_mec(self): self._dequeue_hqds(reset=True) + + # issue a soft reset to reset aql sync counter on multixcc systems. + if self.xccs > 1: + for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_gfx=1, inst=xcc) + time.sleep(0.05) + for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc) + self._config_mec() self._enable_mec() def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, idx:int, aql:bool) -> tuple[int, int]: - self.adev.has_aql_queue |= aql pipe, queue, doorbell = idx // 4, idx % 4, am.AMDGPU_NAVI10_DOORBELL_MEC_RING0 - self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0) - restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1) - restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0 - if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.") for xcc in range(self.xccs if aql else 1): + self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc) + struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}{'_compute' if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else ''}_mqd") mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(self.mqd_mc[queue] + 0x1000*xcc), cp_mqd_base_addr_hi=hi32(self.mqd_mc[queue] + 0x1000*xcc), cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111, @@ -320,26 +324,16 @@ class AM_GFX(AM_IP): **({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc, 'cp_mqd_stride_size':0x1000} if aql and self.xccs > 1 else {})) for se in range(8 if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else 4): setattr(mqd_struct, f'compute_static_thread_mgmt_se{se}', 0xffffffff) - # Copy mqd into memory - self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc) + self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B') - if restore_queue: - for r in [self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR_HI, - self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR_HI, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR_HI]: - val = memoryview(bytes(mqd_struct)).cast('I')[0x80 + (off:=r.addr[xcc] - self.adev.regCP_MQD_BASE_ADDR.addr[xcc])] - self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct), fmt='I')[0x80 + off] = val - r.write(val, inst=xcc) - else: - self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B') - - mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I') - for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)): - self.adev.wreg(reg, mqd_st_mv[0x80 + i]) - self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc) + mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I') + for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)): + self.adev.wreg(reg, mqd_st_mv[0x80 + i]) + self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc) self.adev.gmc.flush_hdp() self._grbm_select(inst=xcc) - return restore_ptr // 16, doorbell + return 0, doorbell def set_clockgating_state(self): if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1) @@ -391,14 +385,12 @@ class AM_GFX(AM_IP): _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc) def _dequeue_hqds(self, reset=False): - # NOTE: For aqls with xccs (queue=1), will continue from the saved state. - for q in range(2 if self.xccs == 1 else 1): + for q in range(2): for xcc in range(self.xccs): self._grbm_select(me=1, pipe=0, queue=q, inst=xcc) if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES - if reset: self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1, inst=xcc) - else: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout") + if not reset: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout") self._grbm_select() class AM_IH(AM_IP):