From d1d9e729fd4ca9f515d65aa14d138cf7cedfab7f Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 28 May 2025 16:53:31 +0300 Subject: [PATCH] am_smi: mem usage (#10547) --- extra/amdpci/am_smi.py | 47 +++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/extra/amdpci/am_smi.py b/extra/amdpci/am_smi.py index 6668310808..4eb9907698 100755 --- a/extra/amdpci/am_smi.py +++ b/extra/amdpci/am_smi.py @@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG, colored, ansilen from tinygrad.runtime.autogen import libc from tinygrad.runtime.autogen.am import am from tinygrad.runtime.support.hcq import MMIOInterface -from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager +from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager, AMPageTableEntry from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA AM_VERSION = 0xA0000004 @@ -27,10 +27,14 @@ def color_temp(temp): def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan") -def draw_bar(percentage, width=40, fill='█', empty='░'): +def draw_bar(percentage, width=40, fill='|', empty=' ', opt_text='', color='cyan'): filled_width = int(width * percentage) + if not opt_text: opt_text = f'{percentage*100:.1f}%' + bar = fill * filled_width + empty * (width - filled_width) - return f'[{bar}] {percentage*100:5.1f}%' + bar = (bar[:-len(opt_text)] + opt_text) if opt_text else bar + bar = colored(bar[:filled_width], color) + bar[filled_width:] + return f'[{bar}]' def same_line(strs:list[list[str]|None], split=8) -> list[str]: strs = [s for s in strs if s is not None] @@ -175,9 +179,25 @@ class SMICtx: def get_power(self, dev, metrics): return metrics.SmuMetrics.AverageSocketPower, metrics.SmuMetrics.dGPU_W_MAX - def draw(self): + def get_mem_usage(self, dev): + usage = 0 + pt_stack = [dev.mm.root_page_table] + while len(pt_stack) > 0: + pt = pt_stack.pop() + for i in range(512): + entry = pt.entries[i] + + if (entry & am.AMDGPU_PTE_VALID) == 0: continue + if pt.lv!=am.AMDGPU_VM_PTB and not dev.gmc.is_pte_huge_page(entry): + pt_stack.append(AMPageTableEntry(dev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)) + continue + if (entry & am.AMDGPU_PTE_SYSTEM) != 0: continue + usage += (1 << ((9 * (3-pt.lv)) + 12)) + return usage + + def draw(self, once): terminal_width, terminal_height = shutil.get_terminal_size() - if self.prev_terminal_width != terminal_width or self.prev_terminal_height != terminal_height: + if not once and (self.prev_terminal_width != terminal_width or self.prev_terminal_height != terminal_height): os.system('clear') self.prev_terminal_width, self.prev_terminal_height = terminal_width, terminal_height @@ -196,9 +216,14 @@ class SMICtx: [pad(f"PCI State: {dev.pci_state}", col_size)]) continue + mem_used = self.get_mem_usage(dev) + mem_total = dev.vram_size + mem_fmt = f"{mem_used/1024**3:.1f}/{mem_total/1024**3:.1f}G" + device_line = [f"{bold(dev.pcibus)} {trim(self.lspci[dev.pcibus[5:]], col_size - 20)}"] + [pad("", col_size)] activity_line = [f"GFX Activity {draw_bar(self.get_gfx_activity(dev, metrics) / 100, activity_line_width)}"] \ - + [f"MEM Activity {draw_bar(self.get_mem_activity(dev, metrics) / 100, activity_line_width)}"] + + [f"MEM Activity {draw_bar(self.get_mem_activity(dev, metrics) / 100, activity_line_width)}"] \ + + [f"MEM Usage {draw_bar((mem_used / mem_total) / 100, activity_line_width, opt_text=mem_fmt)}"] \ temps_data, temps_data_compact = self.get_temps(dev, metrics), self.get_temps(dev, metrics, compact=True) temps_table = ["=== Temps (°C) ==="] + [f"{name:<16}: {color_temp(val)}" for name, val in temps_data.items()] @@ -208,8 +233,8 @@ class SMICtx: power_table = ["=== Power ==="] + [f"Fan Speed: {fan_rpm} RPM"] + [f"Fan Power: {fan_pwm}%"] total_power, max_power = self.get_power(dev, metrics) - power_line = [f"Power: {total_power:>3}W " + draw_bar(total_power / max_power, 16)] - power_line_compact = [f"Power: {total_power:>3}W " + draw_bar(total_power / max_power, activity_line_width)] + power_line = [f"Power: " + draw_bar(total_power / max_power, 16, opt_text=f"{total_power}/{max_power}W")] + power_line_compact = [f"Power: " + draw_bar(total_power / max_power, activity_line_width, opt_text=f"{total_power}/{max_power}W")] voltage_data = self.get_voltage(dev, metrics) voltage_table = ["=== Voltages ==="] + [f"{name:<20}: {color_voltage(voltage)}" for name, voltage in voltage_data.items()] @@ -252,6 +277,7 @@ class SMICtx: if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--list", action="store_true", help="Run once and exit") parser.add_argument("--pids", action="store_true", help="Print pids for all AM devices") parser.add_argument("--kill", action="store_true", help="Kill all pids associated with AM devices. Valid only with --pids") parser.add_argument("--dev", type=str, default=None, help="PCI bus ID of the AM device to monitor (e.g., 0000:01:00.0)") @@ -276,10 +302,11 @@ if __name__ == "__main__": sys.exit(0) try: - os.system('clear') + if not args.list: os.system('clear') smi_ctx = SMICtx() while True: smi_ctx.rescan_devs() - smi_ctx.draw() + smi_ctx.draw(args.list) + if args.list: break time.sleep(1) except KeyboardInterrupt: print("Exiting...")