From 801ec9e69728ffcef9b66ac8d97df23bfc259988 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 28 Jan 2025 20:18:46 +0300 Subject: [PATCH] am: no hardcoded clocks (#8788) * am: no hardcoded clocks * better --- tinygrad/runtime/support/am/amdev.py | 4 ++-- tinygrad/runtime/support/am/ip.py | 31 ++++++++++++++++------------ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 9a3480c7ef..14f1643fa4 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -290,14 +290,14 @@ class AMDev: ip.init() if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized") - self.smu.set_clocks(perf=True) + self.smu.set_clocks(level=-1) # last level, max perf. self.gfx.set_clockgating_state() self.reg("regSCRATCH_REG7").write(am_version) if DEBUG >= 2: print(f"am {self.devfmt}: boot done") def fini(self): for ip in [self.sdma, self.gfx]: ip.fini() - self.smu.set_clocks(perf=False) + self.smu.set_clocks(level=0) def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 421b53bcfd..b125f45a56 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -107,29 +107,34 @@ class AM_SMU(AM_IP): self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True) def init(self): - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) def is_smu_alive(self): - with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100) + with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100) return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0 def mode1_reset(self): if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset") - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms def read_table(self, table_t, cmd): - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t))) def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS) - def set_clocks(self, perf): - # TODO: Parse from bios. - for clck, (mn, mx) in {smu_v13_0_0.PPCLK_GFXCLK: (0, 3220), smu_v13_0_0.PPCLK_UCLK: (0, 1249), smu_v13_0_0.PPCLK_FCLK: (0, 2301)}.items(): - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (mx if perf else mn), poll=True) - self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (mx if perf else mn), poll=True) + def set_clocks(self, level): + if not hasattr(self, 'clcks'): + self.clcks = {} + for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]: + cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff + self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)] + + for clck, vals in self.clcks.items(): + self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True) + self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True) def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): @@ -137,12 +142,12 @@ class AM_SMU(AM_IP): self.adev.mmMP1_SMN_C2PMSG_82.write(param) self.adev.mmMP1_SMN_C2PMSG_66.write(msg) - def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s + def _send_msg(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s if poll: self._smu_cmn_poll_stat(timeout=timeout) self._smu_cmn_send_msg(msg, param) self._smu_cmn_poll_stat(timeout=timeout) - return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None + return self.adev.mmMP1_SMN_C2PMSG_82.read() if read_back_arg else None class AM_GFX(AM_IP): def init(self):