mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
amd: support xcc in regs (#11670)
* amd: support xcc in regs * mockamd * typong
This commit is contained in:
@@ -36,7 +36,8 @@ def main():
|
||||
reg_names = {}
|
||||
dev = PCIIface(None, 0)
|
||||
for x, y in dev.dev_impl.__dict__.items():
|
||||
if isinstance(y, AMRegister): reg_names[y.addr] = x
|
||||
if isinstance(y, AMRegister):
|
||||
for inst, addr in y.addr.keys(): reg_names[addr] = f"{x}, xcc={inst}"
|
||||
|
||||
with open(sys.argv[1], 'r') as f:
|
||||
log_content = log_content_them = f.read()
|
||||
|
||||
@@ -87,16 +87,19 @@ class AMDDriver(VirtDriver):
|
||||
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id))),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0',
|
||||
functools.partial(DirFileDesc, child_names=[str(am.GC_HWID), str(am.SDMA0_HWID), str(am.NBIF_HWID)])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/major', functools.partial(TextFileDesc, text='11')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/base_addr',
|
||||
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/major', functools.partial(TextFileDesc, text='6')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/base_addr',
|
||||
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/major', functools.partial(TextFileDesc, text='4')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/minor', functools.partial(TextFileDesc, text='3')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
|
||||
@@ -45,12 +45,12 @@ class AMDComputeQueue(HWQueue):
|
||||
|
||||
def wreg(self, reg:AMDReg, *args:sint, **kwargs:int):
|
||||
if bool(args) == bool(kwargs): raise RuntimeError('One (and only one) of *args or **kwargs must be specified')
|
||||
if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr < self.pm4.PACKET3_SET_SH_REG_END:
|
||||
if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_SH_REG_END:
|
||||
set_packet, set_packet_start = self.pm4.PACKET3_SET_SH_REG, self.pm4.PACKET3_SET_SH_REG_START
|
||||
elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
|
||||
elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
|
||||
set_packet, set_packet_start = self.pm4.PACKET3_SET_UCONFIG_REG, self.pm4.PACKET3_SET_UCONFIG_REG_START
|
||||
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr}) via pm4 packet')
|
||||
self.pkt3(set_packet, reg.addr - set_packet_start, *(args or (reg.encode(**kwargs),)))
|
||||
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr[0]}) via pm4 packet')
|
||||
self.pkt3(set_packet, reg.addr[0] - set_packet_start, *(args or (reg.encode(**kwargs),)))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def pred_exec(self, xcc_mask:int):
|
||||
@@ -119,8 +119,8 @@ class AMDComputeQueue(HWQueue):
|
||||
|
||||
def memory_barrier(self):
|
||||
pf = '' if self.nbio.version[0] == 2 else '0' if self.nbio.version[:2] != (7, 11) else '1'
|
||||
self.wait_reg_mem(reg_req=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr,
|
||||
reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr, value=0xffffffff)
|
||||
self.wait_reg_mem(reg_req=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0],
|
||||
reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff)
|
||||
self.acquire_mem()
|
||||
return self
|
||||
|
||||
@@ -190,17 +190,17 @@ class AMDComputeQueue(HWQueue):
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, instance_broadcast_writes=1)
|
||||
# Wait for FINISH_PENDING==0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4)
|
||||
# Wait for FINISH_DONE!=0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_NEQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4)
|
||||
# Disable SQTT
|
||||
self.sqtt_config(tracing=False)
|
||||
# Wait for BUSY==0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr[0], 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4)
|
||||
# Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True)
|
||||
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr, 0, *data64_le(wptrs.va_addr+(se*4)))
|
||||
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4)))
|
||||
# Restore global broadcasting
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, se_broadcast_writes=1, sa_broadcast_writes=1, instance_broadcast_writes=1)
|
||||
self.spi_config(tracing=False)
|
||||
@@ -539,10 +539,10 @@ class KFDIface:
|
||||
self.props = {(p:=l.split())[0]: int(p[1]) for l in FileIOInterface(f"{kfd_topo_path}/{KFDIface.gpus[device_id]}/properties").read().splitlines()}
|
||||
ip_base = f"/sys/class/drm/renderD{self.props['drm_render_minor']}/device/ip_discovery/die/0"
|
||||
id2ip = {am.GC_HWID: am.GC_HWIP, am.SDMA0_HWID: am.SDMA0_HWIP, am.NBIF_HWID: am.NBIF_HWIP}
|
||||
self.ip_versions = {id2ip[int(hwid)]:tuple(int(FileIOInterface(f'{ip_base}/{hwid}/0/{part}').read()) for part in ['major', 'minor', 'revision'])
|
||||
for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip}
|
||||
self.ip_offsets = {id2ip[int(hwid)]:tuple(int(x, 16) for x in FileIOInterface(f'{ip_base}/{hwid}/0/base_addr').read().splitlines())
|
||||
for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip}
|
||||
ip_hw = [(id2ip[int(hwid)], int(hwid)) for hwid in FileIOInterface(ip_base).listdir() if hwid.isnumeric() and int(hwid) in id2ip]
|
||||
self.ip_versions = {ip:tuple(int(FileIOInterface(f'{ip_base}/{hw}/0/{part}').read()) for part in ['major','minor','revision']) for ip,hw in ip_hw}
|
||||
self.ip_offsets = {ip:{int(i):tuple(int(x, 16) for x in FileIOInterface(f'{ip_base}/{hw}/{i}/base_addr').read().splitlines())
|
||||
for i in FileIOInterface(f'{ip_base}/{hw}').listdir()} for ip,hw in ip_hw }
|
||||
self.drm_fd = FileIOInterface(f"/dev/dri/renderD{self.props['drm_render_minor']}", os.O_RDWR)
|
||||
|
||||
kfd.AMDKFD_IOC_ACQUIRE_VM(KFDIface.kfd, drm_fd=self.drm_fd.fd, gpu_id=self.gpu_id)
|
||||
@@ -653,8 +653,7 @@ class PCIIface(PCIIfaceBase):
|
||||
|
||||
def _setup_adev(self, name, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
|
||||
self.dev_impl:AMDev = AMDev(name, vram, doorbell, mmio, dma_regions)
|
||||
self.ip_versions = self.dev_impl.ip_ver
|
||||
self.ip_offsets = {hwip: tuple(instances[0]) for hwip,instances in self.dev_impl.regs_offset.items()}
|
||||
self.ip_offsets, self.ip_versions = self.dev_impl.regs_offset, self.dev_impl.ip_ver
|
||||
|
||||
gfxver = int(f"{self.dev_impl.ip_ver[am.GC_HWIP][0]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][1]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][2]:02d}")
|
||||
array_count = self.dev_impl.gc_info.gc_num_sa_per_se * self.dev_impl.gc_info.gc_num_se
|
||||
@@ -762,7 +761,7 @@ class AMDDevice(HCQCompiled):
|
||||
|
||||
nbio_name = 'nbio' if self.target[0] < 12 else 'nbif'
|
||||
nbio_pad = (0,) if self.target[0] == 9 else ()
|
||||
self.nbio = AMDIP(nbio_name, self.iface.ip_versions[am.NBIF_HWIP], nbio_pad+self.iface.ip_offsets[am.NBIF_HWIP])
|
||||
self.nbio = AMDIP(nbio_name, self.iface.ip_versions[am.NBIF_HWIP], {i:nbio_pad+x for i,x in self.iface.ip_offsets[am.NBIF_HWIP].items()})
|
||||
|
||||
self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x2000 if self.is_usb() else (16 << 20), eop_buffer_size=0x1000,
|
||||
ctx_save_restore_size=0 if self.is_am() else wg_data_size + ctl_stack_size, ctl_stack_size=ctl_stack_size, debug_memory_size=debug_memory_size)
|
||||
|
||||
@@ -10,16 +10,16 @@ from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU
|
||||
|
||||
AM_DEBUG = getenv("AM_DEBUG", 0)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class AMRegister(AMDReg):
|
||||
adev:AMDev
|
||||
|
||||
def read(self): return self.adev.rreg(self.addr)
|
||||
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
|
||||
def read(self, inst=0): return self.adev.rreg(self.addr[inst])
|
||||
def read_bitfields(self, inst=0) -> dict[str, int]: return self.decode(self.read(inst=inst))
|
||||
|
||||
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
|
||||
def write(self, _am_val:int=0, inst=0, **kwargs): self.adev.wreg(self.addr[inst], _am_val | self.encode(**kwargs))
|
||||
|
||||
def update(self, **kwargs): self.write(self.read() & ~self.fields_mask(*kwargs.keys()), **kwargs)
|
||||
def update(self, inst=0, **kwargs): self.write(self.read(inst=inst) & ~self.fields_mask(*kwargs.keys()), inst=inst, **kwargs)
|
||||
|
||||
class AMFirmware:
|
||||
def __init__(self, adev):
|
||||
@@ -254,6 +254,6 @@ class AMDev(PCIDevImplBase):
|
||||
("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
|
||||
|
||||
for prefix, hwip in mods:
|
||||
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip][0])))
|
||||
self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP][0])))
|
||||
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))
|
||||
self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP])))
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ class AM_GFX(AM_IP):
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue)
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[0], self.adev.regCP_HQD_PQ_WPTR_HI.addr[0] + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1)
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ from tinygrad.helpers import getbits, round_up, fetch
|
||||
from tinygrad.runtime.autogen import pci
|
||||
from tinygrad.runtime.support.usb import ASM24Controller
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass
|
||||
class AMDReg:
|
||||
name:str; offset:int; segment:int; fields:dict[str, tuple[int, int]]; bases:tuple[int, ...] # noqa: E702
|
||||
name:str; offset:int; segment:int; fields:dict[str, tuple[int, int]]; bases:dict[int, tuple[int, ...]] # noqa: E702
|
||||
def __post_init__(self): self.addr:dict[int, int] = { inst: bases[self.segment] + self.offset for inst, bases in self.bases.items() }
|
||||
|
||||
def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0)
|
||||
def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}
|
||||
@@ -15,12 +16,9 @@ class AMDReg:
|
||||
def fields_mask(self, *names) -> int:
|
||||
return functools.reduce(int.__or__, ((((1 << (self.fields[nm][1]-self.fields[nm][0]+1)) - 1) << self.fields[nm][0]) for nm in names), 0)
|
||||
|
||||
@property
|
||||
def addr(self): return self.bases[self.segment] + self.offset
|
||||
|
||||
@dataclass
|
||||
class AMDIP:
|
||||
name:str; version:tuple[int, ...]; bases:tuple[int, ...] # noqa: E702
|
||||
name:str; version:tuple[int, ...]; bases:dict[int, tuple[int, ...]] # noqa: E702
|
||||
def __post_init__(self): self.version = fixup_ip_version(self.name, self.version)[0]
|
||||
|
||||
@functools.cached_property
|
||||
|
||||
Reference in New Issue
Block a user