diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 669e8bdd70..adcd131bfd 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -187,7 +187,10 @@ class AMDev: self.init_hw(self.gfx, self.sdma) self.pci_dev.write_config(pci.PCI_COMMAND, self.pci_dev.read_config(pci.PCI_COMMAND, 2) | pci.PCI_COMMAND_MASTER, 2) - self.smu.set_clocks(level=-1) # last level, max perf. + if (max_power:=getenv("AM_POWER_LIMIT", 0.0)) > 0: + self.smu.set_power_limit(max_power) + self.smu.set_clocks(level=None) + else: self.smu.set_clocks(level=-1) # last level, max perf. for ip in [self.soc, self.gfx]: ip.set_clockgating_state() self.reg("regSCRATCH_REG7").write(AMDev.Version) self.reg("regSCRATCH_REG6").write(1) # set initialized state. diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 12bb947460..6119d7dc59 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -203,14 +203,25 @@ class AM_SMU(AM_IP): return {clck: [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)] for clck in clk_list if (cnt:=self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff)} - def set_clocks(self, level:int): + def set_clocks(self, level:int|None): clks = tuple([self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]) if self.adev.ip_ver[am.MP0_HWIP] not in {(13,0,6), (13,0,12)}: clks += (self.smu_mod.PPCLK_GFXCLK,) + if level is None: + for clck in clks: + with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16, timeout=20) + if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | 0xffff) + return + for clck, vals in self.read_clocks(clks).items(): with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), timeout=20) if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level])) + def set_power_limit(self, watts:float): + ppt_limit = max(int(round(watts)), 1) + self._send_msg(self.smu_mod.PPSMC_MSG_SetPptLimit, ppt_limit) + if DEBUG >= 2: print(f"am {self.adev.devfmt}: GPU power limit set to {ppt_limit}W") + def _aca_read_reg(self, bank_idx:int, reg_idx:int, ue=True) -> int: msg = self.smu_mod.PPSMC_MSG_McaBankDumpDW if ue else self.smu_mod.PPSMC_MSG_McaBankCeDumpDW return (self._send_msg(msg, (bank_idx << 16) | (reg_idx * 8 + 4), read_back_arg=True) << 32) | \