amd: support xcc in regs (#11670)

* amd: support xcc in regs

* mockamd

* typong
This commit is contained in:
nimlgen
2025-08-14 21:20:11 +03:00
committed by GitHub
parent f399d0d75d
commit 4176b24264
6 changed files with 33 additions and 32 deletions

View File

@@ -36,7 +36,8 @@ def main():
reg_names = {}
dev = PCIIface(None, 0)
for x, y in dev.dev_impl.__dict__.items():
if isinstance(y, AMRegister): reg_names[y.addr] = x
if isinstance(y, AMRegister):
for inst, addr in y.addr.keys(): reg_names[addr] = f"{x}, xcc={inst}"
with open(sys.argv[1], 'r') as f:
log_content = log_content_them = f.read()

View File

@@ -87,16 +87,19 @@ class AMDDriver(VirtDriver):
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id))),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0',
functools.partial(DirFileDesc, child_names=[str(am.GC_HWID), str(am.SDMA0_HWID), str(am.NBIF_HWID)])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/major', functools.partial(TextFileDesc, text='11')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/base_addr',
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/major', functools.partial(TextFileDesc, text='6')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/base_addr',
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/major', functools.partial(TextFileDesc, text='4')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/minor', functools.partial(TextFileDesc, text='3')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),

View File

@@ -45,12 +45,12 @@ class AMDComputeQueue(HWQueue):
def wreg(self, reg:AMDReg, *args:sint, **kwargs:int):
if bool(args) == bool(kwargs): raise RuntimeError('One (and only one) of *args or **kwargs must be specified')
if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr < self.pm4.PACKET3_SET_SH_REG_END:
if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_SH_REG_END:
set_packet, set_packet_start = self.pm4.PACKET3_SET_SH_REG, self.pm4.PACKET3_SET_SH_REG_START
elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
set_packet, set_packet_start = self.pm4.PACKET3_SET_UCONFIG_REG, self.pm4.PACKET3_SET_UCONFIG_REG_START
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr}) via pm4 packet')
self.pkt3(set_packet, reg.addr - set_packet_start, *(args or (reg.encode(**kwargs),)))
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr[0]}) via pm4 packet')
self.pkt3(set_packet, reg.addr[0] - set_packet_start, *(args or (reg.encode(**kwargs),)))
@contextlib.contextmanager
def pred_exec(self, xcc_mask:int):
@@ -119,8 +119,8 @@ class AMDComputeQueue(HWQueue):
def memory_barrier(self):
pf = '' if self.nbio.version[0] == 2 else '0' if self.nbio.version[:2] != (7, 11) else '1'
self.wait_reg_mem(reg_req=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr,
reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr, value=0xffffffff)
self.wait_reg_mem(reg_req=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0],
reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff)
self.acquire_mem()
return self
@@ -190,17 +190,17 @@ class AMDComputeQueue(HWQueue):
self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, instance_broadcast_writes=1)
# Wait for FINISH_PENDING==0
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4)
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4)
# Wait for FINISH_DONE!=0
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_NEQ),
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4)
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4)
# Disable SQTT
self.sqtt_config(tracing=False)
# Wait for BUSY==0
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4)
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4)
# Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True)
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr, 0, *data64_le(wptrs.va_addr+(se*4)))
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4)))
# Restore global broadcasting
self.wreg(self.gc.regGRBM_GFX_INDEX, se_broadcast_writes=1, sa_broadcast_writes=1, instance_broadcast_writes=1)
self.spi_config(tracing=False)
@@ -539,10 +539,10 @@ class KFDIface:
self.props = {(p:=l.split())[0]: int(p[1]) for l in FileIOInterface(f"{kfd_topo_path}/{KFDIface.gpus[device_id]}/properties").read().splitlines()}
ip_base = f"/sys/class/drm/renderD{self.props['drm_render_minor']}/device/ip_discovery/die/0"
id2ip = {am.GC_HWID: am.GC_HWIP, am.SDMA0_HWID: am.SDMA0_HWIP, am.NBIF_HWID: am.NBIF_HWIP}
self.ip_versions = {id2ip[int(hwid)]:tuple(int(FileIOInterface(f'{ip_base}/{hwid}/0/{part}').read()) for part in ['major', 'minor', 'revision'])
for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip}
self.ip_offsets = {id2ip[int(hwid)]:tuple(int(x, 16) for x in FileIOInterface(f'{ip_base}/{hwid}/0/base_addr').read().splitlines())
for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip}
ip_hw = [(id2ip[int(hwid)], int(hwid)) for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip]
self.ip_versions = {ip:tuple(int(FileIOInterface(f'{ip_base}/{hw}/0/{part}').read()) for part in ['major','minor','revision']) for ip,hw in ip_hw}
self.ip_offsets = {ip:{int(i):tuple(int(x, 16) for x in FileIOInterface(f'{ip_base}/{hw}/{i}/base_addr').read().splitlines())
for i in FileIOInterface(f'{ip_base}/{hw}').listdir()} for ip,hw in ip_hw }
self.drm_fd = FileIOInterface(f"/dev/dri/renderD{self.props['drm_render_minor']}", os.O_RDWR)
kfd.AMDKFD_IOC_ACQUIRE_VM(KFDIface.kfd, drm_fd=self.drm_fd.fd, gpu_id=self.gpu_id)
@@ -653,8 +653,7 @@ class PCIIface(PCIIfaceBase):
def _setup_adev(self, name, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
self.dev_impl:AMDev = AMDev(name, vram, doorbell, mmio, dma_regions)
self.ip_versions = self.dev_impl.ip_ver
self.ip_offsets = {hwip: tuple(instances[0]) for hwip,instances in self.dev_impl.regs_offset.items()}
self.ip_offsets, self.ip_versions = self.dev_impl.regs_offset, self.dev_impl.ip_ver
gfxver = int(f"{self.dev_impl.ip_ver[am.GC_HWIP][0]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][1]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][2]:02d}")
array_count = self.dev_impl.gc_info.gc_num_sa_per_se * self.dev_impl.gc_info.gc_num_se
@@ -762,7 +761,7 @@ class AMDDevice(HCQCompiled):
nbio_name = 'nbio' if self.target[0] < 12 else 'nbif'
nbio_pad = (0,) if self.target[0] == 9 else ()
self.nbio = AMDIP(nbio_name, self.iface.ip_versions[am.NBIF_HWIP], nbio_pad+self.iface.ip_offsets[am.NBIF_HWIP])
self.nbio = AMDIP(nbio_name, self.iface.ip_versions[am.NBIF_HWIP], {i:nbio_pad+x for i,x in self.iface.ip_offsets[am.NBIF_HWIP].items()})
self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x2000 if self.is_usb() else (16 << 20), eop_buffer_size=0x1000,
ctx_save_restore_size=0 if self.is_am() else wg_data_size + ctl_stack_size, ctl_stack_size=ctl_stack_size, debug_memory_size=debug_memory_size)

View File

@@ -10,16 +10,16 @@ from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU
AM_DEBUG = getenv("AM_DEBUG", 0)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class AMRegister(AMDReg):
adev:AMDev
def read(self): return self.adev.rreg(self.addr)
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
def read(self, inst=0): return self.adev.rreg(self.addr[inst])
def read_bitfields(self, inst=0) -> dict[str, int]: return self.decode(self.read(inst=inst))
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
def write(self, _am_val:int=0, inst=0, **kwargs): self.adev.wreg(self.addr[inst], _am_val | self.encode(**kwargs))
def update(self, **kwargs): self.write(self.read() & ~self.fields_mask(*kwargs.keys()), **kwargs)
def update(self, inst=0, **kwargs): self.write(self.read(inst=inst) & ~self.fields_mask(*kwargs.keys()), inst=inst, **kwargs)
class AMFirmware:
def __init__(self, adev):
@@ -254,6 +254,6 @@ class AMDev(PCIDevImplBase):
("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
for prefix, hwip in mods:
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip][0])))
self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP][0])))
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))
self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP])))

View File

@@ -249,7 +249,7 @@ class AM_GFX(AM_IP):
self._grbm_select(me=1, pipe=pipe, queue=queue)
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[0], self.adev.regCP_HQD_PQ_WPTR_HI.addr[0] + 1)):
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
self.adev.regCP_HQD_ACTIVE.write(0x1)

View File

@@ -5,9 +5,10 @@ from tinygrad.helpers import getbits, round_up, fetch
from tinygrad.runtime.autogen import pci
from tinygrad.runtime.support.usb import ASM24Controller
@dataclass(frozen=True)
@dataclass
class AMDReg:
name:str; offset:int; segment:int; fields:dict[str, tuple[int, int]]; bases:tuple[int, ...] # noqa: E702
name:str; offset:int; segment:int; fields:dict[str, tuple[int, int]]; bases:dict[int, tuple[int, ...]] # noqa: E702
def __post_init__(self): self.addr:dict[int, int] = { inst: bases[self.segment] + self.offset for inst, bases in self.bases.items() }
def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0)
def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}
@@ -15,12 +16,9 @@ class AMDReg:
def fields_mask(self, *names) -> int:
return functools.reduce(int.__or__, ((((1 << (self.fields[nm][1]-self.fields[nm][0]+1)) - 1) << self.fields[nm][0]) for nm in names), 0)
@property
def addr(self): return self.bases[self.segment] + self.offset
@dataclass
class AMDIP:
name:str; version:tuple[int, ...]; bases:tuple[int, ...] # noqa: E702
name:str; version:tuple[int, ...]; bases:dict[int, tuple[int, ...]] # noqa: E702
def __post_init__(self): self.version = fixup_ip_version(self.name, self.version)[0]
@functools.cached_property