am: AMRegister refactor (#9572)

This commit is contained in:
uuuvn
2025-03-25 22:52:40 +05:00
committed by GitHub
parent cddd750d68
commit 2c32126fc8
3 changed files with 48 additions and 42 deletions

View File

@@ -1,35 +1,26 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
import ctypes, collections, time, dataclasses, functools, pathlib, fcntl, os, importlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0
from tinygrad.runtime.support.amd import AMDRegBase, collect_registers
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
AM_DEBUG = getenv("AM_DEBUG", 0)
@dataclasses.dataclass(frozen=True)
class AMRegister:
adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
class AMRegister(AMDRegBase):
adev:AMDev; hwip:int # noqa: E702
def _parse_kwargs(self, **kwargs):
mask, values = 0xffffffff, 0
for k, v in kwargs.items():
if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
m, s = self.reg_fields[k]
if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
mask &= ~m
values |= v << s
return mask, values
@property
def addr(self): return self.adev.regs_offset[self.hwip][0][self.segment] + self.offset
def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
def read(self): return self.adev.rreg(self.addr)
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
def update(self, **kwargs): self.write(value=self.read(), **kwargs)
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
def write(self, value=0, **kwargs):
mask, values = self._parse_kwargs(**kwargs)
self.adev.wreg(self.reg_off, (value & mask) | values)
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs}))
class AMFirmware:
def __init__(self, adev):
@@ -327,8 +318,6 @@ class AMDev:
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
def rreg(self, reg:int) -> int:
@@ -358,7 +347,7 @@ class AMDev:
for _ in range(timeout):
if ((rval:=reg.read()) & mask) == value: return rval
time.sleep(0.001)
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
def _run_discovery(self):
# NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
@@ -403,14 +392,5 @@ class AMDev:
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
("HDP", self._ip_module("hdp", am.HDP_HWIP))]
for base, module in mods:
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
for k, val in module.__dict__.items():
if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
for k, regval in module.__dict__.items():
if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
for ip, module in mods:
self.__dict__.update(collect_registers(module, cls=functools.partial(AMRegister, adev=self, hwip=getattr(am, f"{ip}_HWIP"))))

View File

@@ -174,7 +174,7 @@ class AM_SMU(AM_IP):
class AM_GFX(AM_IP):
def init_hw(self):
# Wait for RLC autoload to complete
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
self._config_gfx_rs64()
self.adev.gmc.init_hub("GC")
@@ -220,17 +220,17 @@ class AM_GFX(AM_IP):
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0,
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2))
# Copy mqd into memory
ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
@@ -239,7 +239,7 @@ class AM_GFX(AM_IP):
self._grbm_select(me=1, pipe=pipe, queue=queue)
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
self.adev.regCP_HQD_ACTIVE.write(0x1)
@@ -315,7 +315,7 @@ class AM_IH(AM_IP):
_, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']:
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)

View File

@@ -0,0 +1,26 @@
import functools
from collections import defaultdict
from dataclasses import dataclass
from math import log2
from tinygrad.helpers import getbits
@dataclass(frozen=True)
class AMDRegBase:
  """Immutable description of one register and its named bitfields.

  Attributes are filled in by `collect_registers`:
    name:    register name as it appears in the register module.
    offset:  value of the register constant in that module.
    segment: value of the matching `<name>_BASE_IDX` constant.
    fields:  lowercase field name -> `(start, end)`, the inclusive bit indices of the
             lowest and highest bit of the field.
  """
  name: str
  offset: int
  segment: int
  fields: dict[str, tuple[int, int]]

  def encode(self, **kwargs) -> int:
    """Pack the given field values into a single register word.

    Args:
      **kwargs: field name -> unsigned value to place in that field.

    Returns:
      The OR of every value shifted to its field's start bit (0 when no fields are given).

    Raises:
      KeyError: a keyword does not name a field of this register.
      ValueError: a value is negative or wider than its field.
    """
    word = 0
    for field, value in kwargs.items():
      start, end = self.fields[field]
      # A value wider than the field would spill into the neighbouring fields once shifted,
      # so reject it here instead of silently producing a corrupted register word.
      if not 0 <= value < (1 << (end - start + 1)):
        raise ValueError(f"Value {value} for {field} of {self.name} does not fit in bits [{start}, {end}]")
      word |= value << start
    return word

  def decode(self, val: int) -> dict:
    """Split a raw register word into a `{field name: field value}` dict covering every known field."""
    return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}
def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
  """Build register descriptors from the constants of a register definition module.

  The module is scanned for three kinds of attributes:
    * `reg<NAME>` / `mm<NAME>`: register offset,
    * `reg<NAME>_BASE_IDX` / `mm<NAME>_BASE_IDX`: segment index of that register,
    * `<NAME>__<FIELD>_MASK`: bit mask of one field of the register.

  Args:
    module: object whose `__dict__` holds the constants described above.
    cls: callable invoked as `cls(name=, offset=, segment=, fields=)` for every register.

  Returns:
    Mapping from the full register name (prefix included) to the object built by `cls`.
    Registers without a matching `_BASE_IDX` constant are left out.

  Raises:
    ValueError: a field mask is zero or negative, so it has no bit range.
  """
  def _split_name(name):
    # Split at the first uppercase character: 'regCP_STAT' -> ('reg', 'CP_STAT').
    pos = next((i for i,c in enumerate(name) if c.isupper()), len(name))
    return name[:pos], name[pos:]

  offsets, bases = {}, {}
  for key, val in module.__dict__.items():
    if _split_name(key)[0] not in {'reg', 'mm'}: continue
    if key.endswith('_BASE_IDX'): bases[key[:-len('_BASE_IDX')]] = val
    else: offsets[key] = val

  fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
  for field_name,field_mask in module.__dict__.items():
    if not ('__' in field_name and field_name.endswith('_MASK')): continue
    # Split on the first '__' only, so a field name that itself contains '__' does not break unpacking.
    reg_name, reg_field_name = field_name[:-len('_MASK')].split('__', 1)
    if field_mask <= 0: raise ValueError(f"Field mask {field_name}={field_mask} has no valid bit range")
    # int.bit_length is exact for any width; float log2 rounds wide masks up and reports a high bit one too large.
    lowest_bit, highest_bit = (field_mask & -field_mask).bit_length() - 1, field_mask.bit_length() - 1
    fields[reg_name][reg_field_name.lower()] = (lowest_bit, highest_bit)

  # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
  return {reg:cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]]) for reg,off in offsets.items() if reg in bases}