From 2c32126fc8da2b28c4255d7ee6dedfafc8b2e15d Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Tue, 25 Mar 2025 22:52:40 +0500 Subject: [PATCH] am: AMRegister refactor (#9572) --- tinygrad/runtime/support/am/amdev.py | 46 ++++++++-------------------- tinygrad/runtime/support/am/ip.py | 18 +++++------ tinygrad/runtime/support/amd.py | 26 ++++++++++++++++ 3 files changed, 48 insertions(+), 42 deletions(-) create mode 100644 tinygrad/runtime/support/amd.py diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 4316d55aac..e29323a4bf 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -1,35 +1,26 @@ from __future__ import annotations -import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib +import ctypes, collections, time, dataclasses, functools, pathlib, fcntl, os, importlib from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp from tinygrad.runtime.autogen.am import am, mp_11_0 +from tinygrad.runtime.support.amd import AMDRegBase, collect_registers from tinygrad.runtime.support.allocator import TLSFAllocator from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA AM_DEBUG = getenv("AM_DEBUG", 0) @dataclasses.dataclass(frozen=True) -class AMRegister: - adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702 +class AMRegister(AMDRegBase): + adev:AMDev; hwip:int # noqa: E702 - def _parse_kwargs(self, **kwargs): - mask, values = 0xffffffff, 0 - for k, v in kwargs.items(): - if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}") - m, s = self.reg_fields[k] - if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}") - mask &= ~m - values |= v << s - return mask, values + @property + def addr(self): return self.adev.regs_offset[self.hwip][0][self.segment] + self.offset - def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1] + def read(self): return self.adev.rreg(self.addr) + def read_bitfields(self) -> dict[str, int]: return self.decode(self.read()) - def update(self, **kwargs): self.write(value=self.read(), **kwargs) + def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs)) - def write(self, value=0, **kwargs): - mask, values = self._parse_kwargs(**kwargs) - self.adev.wreg(self.reg_off, (value & mask) | values) - - def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0] + def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs})) class AMFirmware: def __init__(self, adev): @@ -327,8 +318,6 @@ class AMDev: def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr - def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg] - def reg(self, reg:str) -> AMRegister: return self.__dict__[reg] def rreg(self, reg:int) -> int: @@ -358,7 +347,7 @@ class AMDev: for _ in range(timeout): if ((rval:=reg.read()) & mask) == value: return rval time.sleep(0.001) - raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}') + raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}') def _run_discovery(self): # NOTE: Fixed register to query memory size without known ip bases to find the discovery table. @@ -403,14 +392,5 @@ class AMDev: ("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)), ("HDP", self._ip_module("hdp", am.HDP_HWIP))] - for base, module in mods: - rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm - reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX")) - reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict) - for k, val in module.__dict__.items(): - if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names): - reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1)) - - for k, regval in module.__dict__.items(): - if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None: - setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {}))) + for ip, module in mods: + self.__dict__.update(collect_registers(module, cls=functools.partial(AMRegister, adev=self, hwip=getattr(am, f"{ip}_HWIP")))) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 23e72a8e8b..593fac6fcd 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -174,7 +174,7 @@ class AM_SMU(AM_IP): class AM_GFX(AM_IP): def init_hw(self): # Wait for RLC autoload to complete - while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass + while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass self._config_gfx_rs64() self.adev.gmc.init_hub("GC") @@ -220,17 +220,17 @@ class AM_GFX(AM_IP): mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True) mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr), - cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1), + cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1), cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111, cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8), cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr), cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr), - cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1), - cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2), - cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000, - cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0, + cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1), + cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2), + cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000, + cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0, cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8), - cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2)) + cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2)) # Copy mqd into memory ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)) @@ -239,7 +239,7 @@ class AM_GFX(AM_IP): self._grbm_select(me=1, pipe=pipe, queue=queue) mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I') - for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)): + for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)): self.adev.wreg(reg, mqd_st_mv[0x80 + i]) self.adev.regCP_HQD_ACTIVE.write(0x1) @@ -315,7 +315,7 @@ class AM_IH(AM_IP): _, rwptr_vm, suf, _ = self.rings[0] wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0] - if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1): + if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']: self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0) diff --git a/tinygrad/runtime/support/amd.py b/tinygrad/runtime/support/amd.py new file mode 100644 index 0000000000..455fe8f903 --- /dev/null +++ b/tinygrad/runtime/support/amd.py @@ -0,0 +1,26 @@ +import functools +from collections import defaultdict +from dataclasses import dataclass +from math import log2 +from tinygrad.helpers import getbits + +@dataclass(frozen=True) +class AMDRegBase: + name: str + offset: int + segment: int + fields: dict[str, tuple[int, int]] + def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0) + def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()} + +def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]: + def _split_name(name): return name[:(pos:=next((i for i,c in enumerate(name) if c.isupper()), len(name)))], name[pos:] + offsets = {k:v for k,v in module.__dict__.items() if _split_name(k)[0] in {'reg', 'mm'} and not k.endswith('_BASE_IDX')} + bases = {k[:-len('_BASE_IDX')]:v for k,v in module.__dict__.items() if _split_name(k)[0] in {'reg', 'mm'} and k.endswith('_BASE_IDX')} + fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict) + for field_name,field_mask in module.__dict__.items(): + if not ('__' in field_name and field_name.endswith('_MASK')): continue + reg_name, reg_field_name = field_name[:-len('_MASK')].split('__') + fields[reg_name][reg_field_name.lower()] = (int(log2(field_mask & -field_mask)), int(log2(field_mask))) + # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them + return {reg:cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]]) for reg,off in offsets.items() if reg in bases}