mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
am: AMRegister refactor (#9572)
This commit is contained in:
@@ -1,35 +1,26 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
|
||||
import ctypes, collections, time, dataclasses, functools, pathlib, fcntl, os, importlib
|
||||
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
|
||||
from tinygrad.runtime.autogen.am import am, mp_11_0
|
||||
from tinygrad.runtime.support.amd import AMDRegBase, collect_registers
|
||||
from tinygrad.runtime.support.allocator import TLSFAllocator
|
||||
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
||||
|
||||
AM_DEBUG = getenv("AM_DEBUG", 0)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AMRegister:
|
||||
adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
|
||||
class AMRegister(AMDRegBase):
|
||||
adev:AMDev; hwip:int # noqa: E702
|
||||
|
||||
def _parse_kwargs(self, **kwargs):
|
||||
mask, values = 0xffffffff, 0
|
||||
for k, v in kwargs.items():
|
||||
if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
|
||||
m, s = self.reg_fields[k]
|
||||
if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
|
||||
mask &= ~m
|
||||
values |= v << s
|
||||
return mask, values
|
||||
@property
|
||||
def addr(self): return self.adev.regs_offset[self.hwip][0][self.segment] + self.offset
|
||||
|
||||
def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
|
||||
def read(self): return self.adev.rreg(self.addr)
|
||||
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
|
||||
|
||||
def update(self, **kwargs): self.write(value=self.read(), **kwargs)
|
||||
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
|
||||
|
||||
def write(self, value=0, **kwargs):
|
||||
mask, values = self._parse_kwargs(**kwargs)
|
||||
self.adev.wreg(self.reg_off, (value & mask) | values)
|
||||
|
||||
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
|
||||
def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs}))
|
||||
|
||||
class AMFirmware:
|
||||
def __init__(self, adev):
|
||||
@@ -327,8 +318,6 @@ class AMDev:
|
||||
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
|
||||
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
|
||||
|
||||
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
|
||||
|
||||
def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
|
||||
|
||||
def rreg(self, reg:int) -> int:
|
||||
@@ -358,7 +347,7 @@ class AMDev:
|
||||
for _ in range(timeout):
|
||||
if ((rval:=reg.read()) & mask) == value: return rval
|
||||
time.sleep(0.001)
|
||||
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
||||
raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
||||
|
||||
def _run_discovery(self):
|
||||
# NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
|
||||
@@ -403,14 +392,5 @@ class AMDev:
|
||||
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
|
||||
("HDP", self._ip_module("hdp", am.HDP_HWIP))]
|
||||
|
||||
for base, module in mods:
|
||||
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
|
||||
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
|
||||
reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
|
||||
for k, val in module.__dict__.items():
|
||||
if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
|
||||
reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
|
||||
|
||||
for k, regval in module.__dict__.items():
|
||||
if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
|
||||
setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
|
||||
for ip, module in mods:
|
||||
self.__dict__.update(collect_registers(module, cls=functools.partial(AMRegister, adev=self, hwip=getattr(am, f"{ip}_HWIP"))))
|
||||
|
||||
@@ -174,7 +174,7 @@ class AM_SMU(AM_IP):
|
||||
class AM_GFX(AM_IP):
|
||||
def init_hw(self):
|
||||
# Wait for RLC autoload to complete
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
|
||||
|
||||
self._config_gfx_rs64()
|
||||
self.adev.gmc.init_hub("GC")
|
||||
@@ -220,17 +220,17 @@ class AM_GFX(AM_IP):
|
||||
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
|
||||
|
||||
mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
|
||||
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
|
||||
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
|
||||
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
||||
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
|
||||
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
|
||||
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
|
||||
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
|
||||
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
|
||||
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
||||
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
|
||||
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
|
||||
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
|
||||
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
||||
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0,
|
||||
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
|
||||
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
|
||||
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2))
|
||||
|
||||
# Copy mqd into memory
|
||||
ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
|
||||
@@ -239,7 +239,7 @@ class AM_GFX(AM_IP):
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue)
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1)
|
||||
|
||||
@@ -315,7 +315,7 @@ class AM_IH(AM_IP):
|
||||
_, rwptr_vm, suf, _ = self.rings[0]
|
||||
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
|
||||
|
||||
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
|
||||
if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']:
|
||||
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
|
||||
|
||||
26
tinygrad/runtime/support/amd.py
Normal file
26
tinygrad/runtime/support/amd.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from math import log2
|
||||
from tinygrad.helpers import getbits
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AMDRegBase:
|
||||
name: str
|
||||
offset: int
|
||||
segment: int
|
||||
fields: dict[str, tuple[int, int]]
|
||||
def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0)
|
||||
def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}
|
||||
|
||||
def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
|
||||
def _split_name(name): return name[:(pos:=next((i for i,c in enumerate(name) if c.isupper()), len(name)))], name[pos:]
|
||||
offsets = {k:v for k,v in module.__dict__.items() if _split_name(k)[0] in {'reg', 'mm'} and not k.endswith('_BASE_IDX')}
|
||||
bases = {k[:-len('_BASE_IDX')]:v for k,v in module.__dict__.items() if _split_name(k)[0] in {'reg', 'mm'} and k.endswith('_BASE_IDX')}
|
||||
fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
|
||||
for field_name,field_mask in module.__dict__.items():
|
||||
if not ('__' in field_name and field_name.endswith('_MASK')): continue
|
||||
reg_name, reg_field_name = field_name[:-len('_MASK')].split('__')
|
||||
fields[reg_name][reg_field_name.lower()] = (int(log2(field_mask & -field_mask)), int(log2(field_mask)))
|
||||
# NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
|
||||
return {reg:cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]]) for reg,off in offsets.items() if reg in bases}
|
||||
Reference in New Issue
Block a user