am_smi: mem usage (#10547)

This commit is contained in:
nimlgen
2025-05-28 16:53:31 +03:00
committed by GitHub
parent 23e41f523a
commit d1d9e729fd

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG, colored, ansilen
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager, AMPageTableEntry
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
AM_VERSION = 0xA0000004
@@ -27,10 +27,14 @@ def color_temp(temp):
def color_voltage(voltage):
  """Return `voltage / 1000` formatted with three decimals and a 'V' suffix, colored cyan."""
  # NOTE(review): the /1000 plus 'V' suffix suggests the input is in millivolts — confirm with callers.
  label = f"{voltage/1000:>5.3f}V"
  return colored(label, "cyan")
def draw_bar(percentage, width=40, fill='|', empty=' ', opt_text='', color='cyan'):
  """Render a fixed-width text progress bar wrapped in square brackets.

  percentage: fill ratio; values outside 0.0..1.0 are clamped when sizing the filled part.
  width: number of character cells between the brackets.
  fill, empty: characters used for the filled and the unfilled cells.
  opt_text: label overlaid right-aligned inside the bar; when empty, the percentage with one decimal is used.
  color: color name handed to `colored` for the filled part of the bar.

  Returns the rendered bar as a string.
  """
  # Clamp the ratio so an out-of-range value cannot produce a bar wider than `width`.
  filled_width = int(width * min(max(percentage, 0.0), 1.0))
  # The default label still reports the value that was passed in, not the clamped one.
  if not opt_text: opt_text = f'{percentage*100:.1f}%'
  bar = fill * filled_width + empty * (width - filled_width)
  # Overlay the label on the rightmost cells. opt_text is never empty at this point; a label at least as
  # long as the bar replaces the bar text entirely.
  bar = bar[:-len(opt_text)] + opt_text
  # Only the filled prefix is colored, which may include part of the overlaid label.
  bar = colored(bar[:filled_width], color) + bar[filled_width:]
  return f'[{bar}]'
def same_line(strs:list[list[str]|None], split=8) -> list[str]:
strs = [s for s in strs if s is not None]
@@ -175,9 +179,25 @@ class SMICtx:
def get_power(self, dev, metrics):
  """Return the `(AverageSocketPower, dGPU_W_MAX)` pair read from `metrics.SmuMetrics`; `dev` is not used."""
  smu = metrics.SmuMetrics
  return smu.AverageSocketPower, smu.dGPU_W_MAX
def draw(self):
def get_mem_usage(self, dev):
  """Walk the page tables of `dev`, starting at `dev.mm.root_page_table`, and total the size of mapped leaf entries.

  An entry is counted when it has AMDGPU_PTE_VALID set, is a leaf (its table is at the AMDGPU_VM_PTB level or
  `dev.gmc.is_pte_huge_page` reports it as a huge page) and does not have AMDGPU_PTE_SYSTEM set.
  Returns the accumulated total as an int (presumably bytes, since the per-entry size starts at 1 << 12).
  """
  total = 0
  pending = [dev.mm.root_page_table]  # explicit stack of tables still to scan, instead of recursion
  while pending:
    table = pending.pop()
    for idx in range(512):
      pte = table.entries[idx]
      if not (pte & am.AMDGPU_PTE_VALID): continue
      if table.lv != am.AMDGPU_VM_PTB and not dev.gmc.is_pte_huge_page(pte):
        # Not a leaf: the masked bits of the entry are used as the address of the next-level table.
        pending.append(AMPageTableEntry(dev, pte & 0x0000FFFFFFFFF000, lv=table.lv+1))
      elif not (pte & am.AMDGPU_PTE_SYSTEM):
        # NOTE(review): entries with the SYSTEM flag are skipped, presumably because they are not backed by VRAM — confirm.
        # One entry covers 1 << 12 at level 3, and 512x (9 more bits) for every level above it.
        total += 1 << (9 * (3 - table.lv) + 12)
  return total
def draw(self, once):
terminal_width, terminal_height = shutil.get_terminal_size()
if self.prev_terminal_width != terminal_width or self.prev_terminal_height != terminal_height:
if not once and (self.prev_terminal_width != terminal_width or self.prev_terminal_height != terminal_height):
os.system('clear')
self.prev_terminal_width, self.prev_terminal_height = terminal_width, terminal_height
@@ -196,9 +216,14 @@ class SMICtx:
[pad(f"PCI State: {dev.pci_state}", col_size)])
continue
mem_used = self.get_mem_usage(dev)
mem_total = dev.vram_size
mem_fmt = f"{mem_used/1024**3:.1f}/{mem_total/1024**3:.1f}G"
device_line = [f"{bold(dev.pcibus)} {trim(self.lspci[dev.pcibus[5:]], col_size - 20)}"] + [pad("", col_size)]
activity_line = [f"GFX Activity {draw_bar(self.get_gfx_activity(dev, metrics) / 100, activity_line_width)}"] \
+ [f"MEM Activity {draw_bar(self.get_mem_activity(dev, metrics) / 100, activity_line_width)}"]
+ [f"MEM Activity {draw_bar(self.get_mem_activity(dev, metrics) / 100, activity_line_width)}"] \
+ [f"MEM Usage {draw_bar((mem_used / mem_total) / 100, activity_line_width, opt_text=mem_fmt)}"] \
temps_data, temps_data_compact = self.get_temps(dev, metrics), self.get_temps(dev, metrics, compact=True)
temps_table = ["=== Temps (°C) ==="] + [f"{name:<16}: {color_temp(val)}" for name, val in temps_data.items()]
@@ -208,8 +233,8 @@ class SMICtx:
power_table = ["=== Power ==="] + [f"Fan Speed: {fan_rpm} RPM"] + [f"Fan Power: {fan_pwm}%"]
total_power, max_power = self.get_power(dev, metrics)
power_line = [f"Power: {total_power:>3}W " + draw_bar(total_power / max_power, 16)]
power_line_compact = [f"Power: {total_power:>3}W " + draw_bar(total_power / max_power, activity_line_width)]
power_line = [f"Power: " + draw_bar(total_power / max_power, 16, opt_text=f"{total_power}/{max_power}W")]
power_line_compact = [f"Power: " + draw_bar(total_power / max_power, activity_line_width, opt_text=f"{total_power}/{max_power}W")]
voltage_data = self.get_voltage(dev, metrics)
voltage_table = ["=== Voltages ==="] + [f"{name:<20}: {color_voltage(voltage)}" for name, voltage in voltage_data.items()]
@@ -252,6 +277,7 @@ class SMICtx:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--list", action="store_true", help="Run once and exit")
parser.add_argument("--pids", action="store_true", help="Print pids for all AM devices")
parser.add_argument("--kill", action="store_true", help="Kill all pids associated with AM devices. Valid only with --pids")
parser.add_argument("--dev", type=str, default=None, help="PCI bus ID of the AM device to monitor (e.g., 0000:01:00.0)")
@@ -276,10 +302,11 @@ if __name__ == "__main__":
sys.exit(0)
try:
os.system('clear')
if not args.list: os.system('clear')
smi_ctx = SMICtx()
while True:
smi_ctx.rescan_devs()
smi_ctx.draw()
smi_ctx.draw(args.list)
if args.list: break
time.sleep(1)
except KeyboardInterrupt: print("Exiting...")