From dafbe9733a1bd3f5f2a3cc7c860f94acda974b3d Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:06:21 +0300 Subject: [PATCH] am: cleanup (#15086) --- tinygrad/runtime/support/am/amdev.py | 4 ++-- tinygrad/runtime/support/am/ip.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 34416cd9af..ee9a5b14f2 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -243,14 +243,14 @@ class AMDev(PCIDevImplBase): def reg(self, reg:str) -> AMRegister: return self.__dict__[reg] def rreg(self, reg:int) -> int: - val = self.indirect_rreg(reg) if reg > len(self.mmio) else self.mmio[reg] + val = self.indirect_rreg(reg) if reg >= len(self.mmio) else self.mmio[reg] if AM_DEBUG >= 4 and getattr(self, '_prev_rreg', None) != (reg, val): print(f"am {self.devfmt}: Reading register {reg:#x} with value {val:#x}") self._prev_rreg = (reg, val) return val def wreg(self, reg:int, val:int): if AM_DEBUG >= 4: print(f"am {self.devfmt}: Writing register {reg:#x} with value {val:#x}") - if reg > len(self.mmio): self.indirect_wreg(reg, val) + if reg >= len(self.mmio): self.indirect_wreg(reg, val) else: self.mmio[reg] = val def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int, inst:int=0): diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 534a0376d0..5d4fb2bb7a 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -25,7 +25,7 @@ class AM_SOC(AM_IP): return {getattr(am, k): k[off+9:] for k in dir(am) if k.startswith(f'{pref}_{self.adev.ip_ver[hwip][0]}') and (off:=k.find('__SRCID__')) != -1} gfx_srcs, sdma_srcs = _ih_srcs('GFX', am.GC_HWIP), _ih_srcs('SDMA0', am.SDMA0_HWIP) - self.ih_scrs_names:dict[int, dict[int, str]] = {**{k: gfx_srcs for k in self.gfx_ih_clients}, **{k: sdma_srcs for k in self.sdma_ih_clients}} + self.ih_srcs_names:dict[int, dict[int, str]] = {**{k: gfx_srcs for k in self.gfx_ih_clients}, **{k: sdma_srcs for k in self.sdma_ih_clients}} def init_hw(self): if self.adev.ip_ver[am.NBIO_HWIP] in {(7,9,0), (7,9,1)}: @@ -240,7 +240,8 @@ class AM_GFX(AM_IP): def init_hw(self): # Wait for RLC autoload to complete - while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass + wait_cond(lambda: self.adev.regCP_STAT.read() == 0 or self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] == 0, + value=True, msg="RLC autoload timeout") self.adev.gmc.init_hub("GC", inst_cnt=self.xccs) if self.adev.partial_boot: return self.reset_mec() @@ -285,7 +286,7 @@ class AM_GFX(AM_IP): self._enable_mec() # Set 1 partition - if self.xccs > 1 and not self.adev.partial_boot: self.adev.psp._spatial_partition_cmd(1) + if self.xccs > 1: self.adev.psp._spatial_partition_cmd(1) def fini_hw(self): self._dequeue_hqds() @@ -435,7 +436,7 @@ class AM_IH(AM_IP): [getattr(am, f'SOC15_{n}_FROM_IH_ENTRY')(entry) for n in ['CLIENT_ID', 'SOURCE_ID', 'RING_ID', 'VMID', 'VMID_TYPE', 'PASID', 'NODEID']] ctx = [getattr(am, f'SOC15_CONTEXT_ID{i}_FROM_IH_ENTRY')(entry) for i in range(4)] - src_name = self.adev.soc.ih_scrs_names.get(client, {}).get(src, '') + src_name = self.adev.soc.ih_srcs_names.get(client, {}).get(src, '') print(f"am {self.adev.devfmt}: IH ({rptr:#x}/{wptr['offset']:#x}) client={self.adev.soc.ih_clients.get(client)} src={src_name}({src}) " f"ring={ring_id} vmid={vmid}({vmid_type}) pasid={pasid} node={node} ctx=[{ctx[0]:#x}, {ctx[1]:#x}, {ctx[2]:#x}, {ctx[3]:#x}]")