mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
am: add am_smi (#8739)
* am: start monitor * cleanups * fixes * hmm * progress * cleanup
This commit is contained in:
167
extra/amdpci/am_smi.py
Normal file
167
extra/amdpci/am_smi.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import time, mmap, sys, shutil, os, glob
|
||||
from tinygrad.helpers import to_mv, DEBUG, colored, ansilen
|
||||
from tinygrad.runtime.autogen import libc
|
||||
from tinygrad.runtime.autogen.am import smu_v13_0_0
|
||||
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
|
||||
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
||||
|
||||
# Marker value expected in regSCRATCH_REG7: a device whose register reads anything else is rejected when opened and dropped on rescan.
AM_VERSION = 0xA0000002
|
||||
|
||||
def bold(s):
  """Return `s` wrapped in the ANSI bold-on (ESC[1m) and reset (ESC[0m) escape sequences."""
  return "\033[1m" + f"{s}" + "\033[0m"
|
||||
|
||||
def color_temp(temp):
  """Format a temperature right-aligned in four columns and color it by severity.

  Readings of 87 and above are red, readings from 80 up to (but not including) 87 are yellow,
  and everything else is white.
  """
  if temp >= 87: shade = "red"
  elif temp >= 80: shade = "yellow"
  else: shade = "white"
  return colored(f"{temp:>4}", shade)
|
||||
|
||||
def color_voltage(voltage):
  """Render `voltage / 1000` with three decimals and a trailing 'V', colored cyan.

  NOTE(review): the division by 1000 suggests the input is in millivolts -- confirm against the SMU metrics table.
  """
  volts = voltage / 1000
  return colored(f"{volts:>5.3f}V", "cyan")
|
||||
|
||||
def draw_bar(percentage, width=40, fill='█', empty='░'):
  """Render a textual progress bar followed by the numeric percentage.

  Args:
    percentage: Fraction to display, where 1.0 means 100%.
    width: Number of cells drawn between the brackets.
    fill: Character used for the filled cells.
    empty: Character used for the remaining cells.

  Returns:
    A string such as '[##--] 50.0%'.

  The bar itself is always exactly `width` cells wide: fractions below 0 or above 1 are clamped for the
  drawing only, while the printed number still shows the real value.
  """
  # Clamp so an out-of-range reading can neither overflow the bar nor add extra empty cells,
  # either of which would break the column alignment of the surrounding layout.
  filled_width = max(0, min(width, int(width * percentage)))
  bar = fill * filled_width + empty * (width - filled_width)
  return f'[{bar}] {percentage*100:.1f}%'
|
||||
|
||||
def same_line(strs:list[list[str]], split=8) -> list[str]:
  """Lay several blocks of text out side by side as columns.

  Args:
    strs: One list of lines per column. Visible widths are measured with `ansilen`, so lines may carry ANSI escapes.
    split: Number of padding spaces appended after every column (a single extra space also joins adjacent columns).

  Returns:
    One string per output row. There are as many rows as the tallest block has lines; shorter blocks are
    filled with blanks so the columns to their right stay aligned. An empty `strs` yields an empty list and
    an empty block is treated as a zero-width column.
  """
  # default=0 keeps max() from raising ValueError on an empty block / an empty list of blocks.
  max_width_in_block = [max((ansilen(line) for line in block), default=0) for block in strs]
  max_height = max((len(block) for block in strs), default=0)

  ret = []
  for i in range(max_height):
    cells = []
    for block, block_width in zip(strs, max_width_in_block):
      if i < len(block): cells.append(block[i] + ' ' * (split + block_width - ansilen(block[i])))
      else: cells.append(' ' * (split + block_width))
    ret.append(' '.join(cells))
  return ret
|
||||
|
||||
def get_bar0_size(pcibus):
  """Return the size in bytes of BAR0 of PCI device `pcibus`, computed from its sysfs `resource` file.

  The first line of /sys/bus/pci/devices/<pcibus>/resource is split on whitespace into three hexadecimal
  fields (start, end, flags); the size is `end - start + 1`.

  Raises:
    FileNotFoundError: the resource file does not exist.
    ValueError: the first line does not hold exactly three fields (this includes an empty file),
      or the start/end fields are not valid hexadecimal.
  """
  resource_file = f"/sys/bus/pci/devices/{pcibus}/resource"

  # Open directly instead of checking for existence first: no window between check and use,
  # and only the first line (BAR0) is read.
  try:
    with open(resource_file, "r") as f: bar0_info = f.readline().split()
  except FileNotFoundError as e:
    raise FileNotFoundError(f"Resource file not found: {resource_file}") from e

  if len(bar0_info) != 3: raise ValueError("Unexpected resource file format for BAR0.")

  start_hex, end_hex, _flags = bar0_info
  return int(end_hex, 16) - int(start_hex, 16) + 1
|
||||
|
||||
class AMSMI(AMDev):
  """AMDev variant used by the monitor to attach to a device through already-mapped BARs.

  The constructor does not call `AMDev.__init__`. It runs discovery, builds the register table, checks that
  regSCRATCH_REG7 holds AM_VERSION, and then constructs the memory manager and the SOC21/GMC/IH/PSP/SMU
  blocks with `partial_boot` and `smi_dev` both set to True.
  """
  def __init__(self, pcibus, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
    """Attach to the device at `pcibus`.

    Args:
      pcibus: PCI bus id of the device; stored as `self.pcibus`.
      vram_bar: View stored as `self.vram`.
      doorbell_bar: Stored as `self.doorbell64`.
      mmio_bar: View stored as `self.mmio`.

    Raises:
      Exception: regSCRATCH_REG7 does not read AM_VERSION.
    """
    self.pcibus = pcibus
    self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar

    self._run_discovery()
    self._build_regs()

    # Read the marker register once so the value reported is exactly the one that failed the check.
    if (am_version:=self.reg("regSCRATCH_REG7").read()) != AM_VERSION:
      raise Exception(f"Unsupported AM version: {am_version:x}")

    self.is_booting, self.smi_dev = True, True
    self.partial_boot = True # do not init anything
    self.mm = AMMemoryManager(self, self.vram_size)

    # Initialize IP blocks
    # NOTE(review): the construction order is kept as-is; presumably it has to match the owning driver so
    # that boot-memory allocations made in the constructors land at the same addresses -- confirm.
    self.soc21:AM_SOC21 = AM_SOC21(self)
    self.gmc:AM_GMC = AM_GMC(self)
    self.ih:AM_IH = AM_IH(self)
    self.psp:AM_PSP = AM_PSP(self)
    self.smu:AM_SMU = AM_SMU(self)
|
||||
|
||||
class SMICtx:
  """State and rendering for the AM monitor.

  Devices are discovered through /tmp/am_<pcibus>.lock files, attached as AMSMI objects, polled for SMU
  metrics and drawn to the terminal.
  """
  def __init__(self):
    self.devs = []  # AMSMI objects currently being monitored
    self.opened_pcidevs = []  # pcibus ids of the devices in self.devs
    self.opened_pci_resources = {}  # pcibus -> (BAR0 view, None, BAR5 view cast to 'I'); mapped once, reused on retries
    self.prev_lines_cnt = 0  # number of lines draw() moves the cursor up before repainting

  @staticmethod
  def _pcibus_from_lock(path):
    """Extract the PCI bus id from a lock file path of the form '<dir>/am_<pcibus>.lock'."""
    return os.path.basename(path)[len("am_"):-len(".lock")]

  @staticmethod
  def _safe_ratio(value, limit):
    """Return `value / limit`, or 0.0 when `limit` is zero so that drawing can never raise ZeroDivisionError."""
    return value / limit if limit else 0.0

  def _open_am_device(self, pcibus):
    """Map the BARs of `pcibus` (first call only) and try to attach an AMSMI to it.

    A failure to construct the AMSMI is swallowed (and printed when DEBUG >= 2); the device is then not
    recorded as opened, so the next rescan retries it using the cached mappings.
    """
    if pcibus not in self.opened_pci_resources:
      bar_fds = {}
      try:
        # Only BAR0 and BAR5 are mapped; the doorbell slot handed to AMSMI is None.
        for bar in [0, 5]: bar_fds[bar] = os.open(f"/sys/bus/pci/devices/{pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC)
        bar_size = {0: get_bar0_size(pcibus), 5: os.fstat(bar_fds[5]).st_size}

        def map_pci_range(bar):
          return to_mv(libc.mmap(0, bar_size[bar], mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, bar_fds[bar], 0), bar_size[bar])
        self.opened_pci_resources[pcibus] = (map_pci_range(0), None, map_pci_range(5).cast('I'))
      finally:
        # A shared mapping stays valid after its descriptor is closed, so the descriptors need not be kept.
        for fd in bar_fds.values(): os.close(fd)

    try:
      self.devs.append(AMSMI(pcibus, *self.opened_pci_resources[pcibus]))
    except Exception as e:
      if DEBUG >= 2: print(f"Failed to open AM device {pcibus}: {e}")
      return

    self.opened_pcidevs.append(pcibus)
    if DEBUG >= 2: print(f"Opened AM device {pcibus}")

  def rescan_devs(self):
    """Attach devices that have a lock file but are not opened yet, and drop devices whose marker register changed."""
    pattern = os.path.join('/tmp', 'am_*.lock')
    for pcibus in [self._pcibus_from_lock(f) for f in glob.glob(pattern)]:
      if pcibus not in self.opened_pcidevs: self._open_am_device(pcibus)

    # Iterate over a snapshot: removing from self.devs while looping over it would skip the device that
    # follows every removed one.
    for d in list(self.devs):
      if d.reg("regSCRATCH_REG7").read() != AM_VERSION:
        self.devs.remove(d)
        self.opened_pcidevs.remove(d.pcibus)
        os.system('clear')
        if DEBUG >= 2: print(f"Removed AM device {d.pcibus}")

  def collect(self):
    """Read the SMU metrics of every monitored device; returns a dict keyed by the device object."""
    return {d: d.smu.read_metrics() for d in self.devs}

  def _device_lines(self, dev, metrics):
    """Build the list of text lines describing one device: header, activity bars, then three side-by-side tables."""
    m = metrics.SmuMetrics

    device_line = [f"PCIe device: {bold(dev.pcibus)}", ""]
    activity_line = [f"GFX Activity {draw_bar(m.AverageGfxActivity / 100, 50)}",
                     f"UCLK Activity {draw_bar(m.AverageUclkActivity / 100, 50)}", ""]

    # Sensors that report 0 are left out of the temperature table.
    temps_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_TEMP_e__enumvalues.items()
                  if k < smu_v13_0_0.TEMP_COUNT and m.AvgTemperature[k] != 0]
    temps_table = ["=== Temps (C) ==="] + [f"{name:<15}: {color_temp(m.AvgTemperature[k])}" for k, name in temps_keys]

    voltage_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_SVI_PLANE_e__enumvalues.items() if k < smu_v13_0_0.SVI_PLANE_COUNT]
    power_table = ["=== Power ===",
                   f"Fan Speed: {m.AvgFanRpm} RPM",
                   f"Fan Power: {m.AvgFanPwm} %",
                   f"Power: {m.AverageSocketPower}W " + draw_bar(self._safe_ratio(m.AverageSocketPower, m.dGPU_W_MAX), 16),
                   "", "=== Voltages ==="] + [f"{name:<24}: {color_voltage(m.AvgVoltage[k])}" for k, name in voltage_keys]

    frequency_table = ["=== Frequencies ===",
                       f"GFXCLK Target : {m.AverageGfxclkFrequencyTarget} MHz",
                       f"GFXCLK PreDs : {m.AverageGfxclkFrequencyPreDs} MHz",
                       f"GFXCLK PostDs : {m.AverageGfxclkFrequencyPostDs} MHz",
                       f"FCLK PreDs : {m.AverageFclkFrequencyPreDs} MHz",
                       f"FCLK PostDs : {m.AverageFclkFrequencyPostDs} MHz",
                       f"MCLK PreDs : {m.AverageMemclkFrequencyPreDs} MHz",
                       f"MCLK PostDs : {m.AverageMemclkFrequencyPostDs} MHz",
                       f"VCLK0 : {m.AverageVclk0Frequency} MHz",
                       f"DCLK0 : {m.AverageDclk0Frequency} MHz",
                       f"VCLK1 : {m.AverageVclk1Frequency} MHz",
                       f"DCLK1 : {m.AverageDclk1Frequency} MHz"]

    return device_line + activity_line + same_line([temps_table, power_table, frequency_table])

  def draw(self):
    """Collect metrics from every device and repaint the screen, two devices per row.

    The cursor is first moved up by the line count remembered from the previous call, so the new frame
    is printed over the old one.
    """
    terminal_width, _ = shutil.get_terminal_size()
    dev_content = [self._device_lines(dev, metrics) for dev, metrics in self.collect().items()]

    parts = ['AM Monitor'.center(terminal_width) + "\n" + "=" * terminal_width + "\n\n"]
    for i in range(0, len(dev_content), 2):
      if i + 1 < len(dev_content): parts.append('\n'.join(same_line([dev_content[i], dev_content[i+1]])))
      else: parts.append('\n'.join(dev_content[i]))
      # Separator between rows of devices, not after the last row.
      if i + 2 < len(dev_content): parts.append("\n" + "=" * terminal_width + "\n\n")
    raw_text = ''.join(parts)

    sys.stdout.write(f'\033[{self.prev_lines_cnt}A')
    sys.stdout.flush()
    print(raw_text)

    self.prev_lines_cnt = len(raw_text.splitlines()) + 2
|
||||
|
||||
if __name__ == "__main__":
  # Clear the screen once, then rescan and repaint once per second until interrupted with Ctrl-C.
  try:
    os.system('clear')
    smi_ctx = SMICtx()
    while True:
      smi_ctx.rescan_devs()
      smi_ctx.draw()
      time.sleep(1)
  except KeyboardInterrupt:
    print("Exiting...")
|
||||
2
test/external/external_test_am.py
vendored
2
test/external/external_test_am.py
vendored
@@ -16,7 +16,7 @@ class FakePCIDev:
|
||||
|
||||
class FakeAM:
|
||||
def __init__(self):
|
||||
self.is_booting = True
|
||||
self.is_booting, self.smi_dev = True, False
|
||||
self.pcidev = FakePCIDev()
|
||||
self.vram = memoryview(bytearray(4 << 30))
|
||||
self.gmc = FakeGMC()
|
||||
|
||||
@@ -487,11 +487,11 @@ class PCIIface:
|
||||
self.pagemap = HWInterface("/proc/self/pagemap", os.O_RDONLY)
|
||||
self.bar_fds = {bar: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]}
|
||||
|
||||
self.adev = AMDev(self.pcidev, self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
|
||||
self.adev = AMDev(self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
|
||||
self.doorbell_cpu_addr = mv_address(dbell)
|
||||
|
||||
libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
|
||||
libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
|
||||
libpciaccess.pci_device_cfg_read_u16(self.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
|
||||
libpciaccess.pci_device_cfg_write_u16(self.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
|
||||
|
||||
# TODO: this is for 7900xtx, the only tested card.
|
||||
self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32,
|
||||
|
||||
@@ -171,7 +171,7 @@ class AMMemoryManager:
|
||||
self.adev, self.vram_size = adev, vram_size
|
||||
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
|
||||
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
|
||||
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
|
||||
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
|
||||
|
||||
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
|
||||
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
|
||||
@@ -231,8 +231,8 @@ class AMMemoryManager:
|
||||
def pfree(self, paddr:int):
  """Return the allocation at physical address `paddr` to `self.pa_allocator`."""
  allocator = self.pa_allocator
  allocator.free(paddr)
|
||||
|
||||
class AMDev:
|
||||
def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
|
||||
self.pcidev, self.devfmt = pcidev, devfmt
|
||||
def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
|
||||
self.devfmt = devfmt
|
||||
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
|
||||
|
||||
os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
|
||||
@@ -256,8 +256,8 @@ class AMDev:
|
||||
# To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
|
||||
# all blocks that are initialized only during the initial AM boot.
|
||||
# To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
|
||||
self.is_booting = True # During boot only boot memory can be allocated. This flag is to validate this.
|
||||
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1)
|
||||
self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
|
||||
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)
|
||||
|
||||
# Memory manager & firmware
|
||||
self.mm = AMMemoryManager(self, self.vram_size)
|
||||
|
||||
@@ -102,7 +102,13 @@ class AM_GMC(AM_IP):
|
||||
if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}")
|
||||
|
||||
class AM_SMU(AM_IP):
|
||||
def __init__(self, adev):
  """Initialise the base IP block, then allocate the 0x4000-byte driver table with `boot=True`.

  The allocation is zero-filled only when the device is not in partial boot.
  """
  super().__init__(adev)
  zero_fill = not self.adev.partial_boot
  self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=zero_fill, boot=True)
|
||||
|
||||
def init(self):
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
|
||||
|
||||
for clck in [0x00000C94, 0x000204E1, 0x000105DC, 0x00050B76, 0x00070B76, 0x00040898, 0x00060898, 0x000308FD]:
|
||||
@@ -118,6 +124,11 @@ class AM_SMU(AM_IP):
|
||||
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
|
||||
time.sleep(0.5) # 500ms
|
||||
|
||||
def read_table(self, table_t, cmd):
  """Send PPSMC_MSG_TransferTableSmu2Dram with argument `cmd`, then view the driver table as a `table_t`.

  The returned structure is created with `from_buffer`, so it shares memory with the driver table
  rather than holding a copy.
  """
  self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
  table_addr = self.adev.paddr2cpu(self.driver_table_paddr)
  table_view = to_mv(table_addr, ctypes.sizeof(table_t))
  return table_t.from_buffer(table_view)
|
||||
def read_metrics(self):
  """Read the TABLE_SMU_METRICS table and return it as a `SmuMetricsExternal_t`."""
  table_t, table_id = smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS
  return self.read_table(table_t, table_id)
|
||||
|
||||
def _smu_cmn_poll_stat(self, timeout=10000):
  """Call `adev.wait_reg` on mmMP1_SMN_C2PMSG_90 with mask 0xFFFFFFFF and expected value 1, forwarding `timeout`."""
  resp_reg = self.adev.mmMP1_SMN_C2PMSG_90
  self.adev.wait_reg(resp_reg, mask=0xFFFFFFFF, value=1, timeout=timeout)
|
||||
def _smu_cmn_send_msg(self, msg, param=0):
|
||||
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
|
||||
@@ -240,8 +251,9 @@ class AM_IH(AM_IP):
|
||||
def __init__(self, adev):
|
||||
super().__init__(adev)
|
||||
self.ring_size = 512 << 10
|
||||
self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
|
||||
(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
|
||||
def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True),
|
||||
self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True))
|
||||
self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)]
|
||||
|
||||
def interrupt_handler(self):
|
||||
_, rwptr_vm, suf, _ = self.rings[0]
|
||||
@@ -318,6 +330,9 @@ class AM_PSP(AM_IP):
|
||||
self.ring_size = 0x10000
|
||||
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)
|
||||
|
||||
self.max_tmr_size = 0x1300000
|
||||
self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True)
|
||||
|
||||
def is_sos_alive(self):
  """Return True when regMP0_SMN_C2PMSG_81 currently reads a non-zero value."""
  status = self.adev.regMP0_SMN_C2PMSG_81.read()
  return status != 0x0
|
||||
def init(self):
|
||||
sos_components_load_order = [
|
||||
@@ -362,7 +377,7 @@ class AM_PSP(AM_IP):
|
||||
# Load TOC and calculate TMR size
|
||||
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
|
||||
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
|
||||
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
|
||||
assert self.tmr_size <= self.max_tmr_size
|
||||
|
||||
def _ring_create(self):
|
||||
# If the ring is already created, destroy it
|
||||
|
||||
Reference in New Issue
Block a user