diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py
index 4878f8582b..15236c047d 100644
--- a/tinygrad/runtime/support/am/ip.py
+++ b/tinygrad/runtime/support/am/ip.py
@@ -169,12 +169,12 @@ class AM_SMU(AM_IP):
     self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]))
     self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))

-  def _smu_cmn_send_msg(self, msg, param=0, debug=False):
+  def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):
     (self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).write(0) # resp reg
     (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).write(param)
     (self.adev.mmMP1_SMN_C2PMSG_66 if not debug else self.adev.mmMP1_SMN_C2PMSG_75).write(msg)

-  def _send_msg(self, msg, param, read_back_arg=False, timeout=10000, debug=False): # 10s
+  def _send_msg(self, msg:int, param:int, read_back_arg=False, timeout=10000, debug=False): # default timeout is 10 seconds
     self._smu_cmn_send_msg(msg, param, debug=debug)
     wait_cond(lambda: (self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).read(), value=1, timeout_ms=timeout,
       msg=f"SMU msg {msg:#x} timeout")
@@ -414,12 +414,12 @@ class AM_PSP(AM_IP):

   def _wait_for_bootloader(self): wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_35").read() & 0x80000000, value=0x80000000, msg="BL not ready")

-  def _prep_msg1(self, data):
+  def _prep_msg1(self, data:memoryview):
     assert len(data) <= self.msg1_view.nbytes, f"msg1 buffer is too small {len(data):#x} > {self.msg1_view.nbytes:#x}"
     self.msg1_view[:len(data)+4] = bytes(data) + b'\x00' * 4
     self.adev.gmc.flush_hdp()

-  def _bootloader_load_component(self, fw, compid):
+  def _bootloader_load_component(self, fw:int, compid:int):
     if fw not in self.adev.fw.sos_fw: return 0

     self._wait_for_bootloader()
@@ -458,7 +458,7 @@ class AM_PSP(AM_IP):

     wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_64").read() & 0x8000FFFF, value=0x80000000, msg="sOS ring not created")

-  def _ring_submit(self, cmd):
+  def _ring_submit(self, cmd:am.struct_psp_gfx_cmd_resp) -> am.struct_psp_gfx_cmd_resp:
     msg = am.struct_psp_gfx_rb_frame(fence_value=(prev_wptr:=self.adev.reg(f"{self.reg_pref}_67").read()),
       cmd_buf_addr_lo=lo32(self.adev.paddr2mc(self.cmd_paddr)), cmd_buf_addr_hi=hi32(self.adev.paddr2mc(self.cmd_paddr)),
       fence_addr_lo=lo32(self.adev.paddr2mc(self.fence_paddr)), fence_addr_hi=hi32(self.adev.paddr2mc(self.fence_paddr)))
@@ -477,7 +477,7 @@ class AM_PSP(AM_IP):

     return resp

-  def _load_ip_fw_cmd(self, fw_types, fw_bytes):
+  def _load_ip_fw_cmd(self, fw_types:list[int], fw_bytes:memoryview):
     self._prep_msg1(fw_bytes)
     for fw_type in fw_types:
       if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[fw_type]}")
@@ -487,7 +487,7 @@
       cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
       self._ring_submit(cmd)

-  def _tmr_load_cmd(self):
+  def _tmr_load_cmd(self) -> am.struct_psp_gfx_cmd_resp:
     cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_SETUP_TMR)
     cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
     cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
@@ -495,7 +495,7 @@
     cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
     return self._ring_submit(cmd)

-  def _load_toc_cmd(self, toc_size):
+  def _load_toc_cmd(self, toc_size:int) -> am.struct_psp_gfx_cmd_resp:
     cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_LOAD_TOC)
     cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_addr)
     cmd.cmd.cmd_load_toc.toc_size = toc_size
diff --git a/tinygrad/runtime/support/nv/ip.py b/tinygrad/runtime/support/nv/ip.py
index bae8e0cbe9..eb4867dc38 100644
--- a/tinygrad/runtime/support/nv/ip.py
+++ b/tinygrad/runtime/support/nv/ip.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import ctypes, time, array, struct, itertools, dataclasses
-from typing import cast
+from typing import cast, Any
 from tinygrad.runtime.autogen.nv import nv
 from tinygrad.helpers import to_mv, lo32, hi32, DEBUG, round_up, round_down, mv_address, fetch, wait_cond
 from tinygrad.runtime.support.system import System
@@ -8,7 +8,7 @@ from tinygrad.runtime.support.elf import elf_loader
 from tinygrad.runtime.autogen import nv_gpu

 @dataclasses.dataclass(frozen=True)
-class GRBufDesc: size:int; v:int; p:int; lc:int=0 # noqa: E702
+class GRBufDesc: size:int; virt:bool; phys:bool; local:bool=False # noqa: E702

 class NV_IP:
   def __init__(self, nvdev): self.nvdev = nvdev
@@ -26,13 +26,13 @@ class NVRpcQueue:
     self.gsp, self.va, self.queue_va, self.seq = gsp, va, va + self.tx.entryOff, 0
     self.queue_mv = to_mv(self.queue_va, self.tx.msgSize * self.tx.msgCount)

-  def _checksum(self, data):
+  def _checksum(self, data:bytes):
     if (pad_len:=(-len(data)) % 8): data += b'\x00' * pad_len

     checksum = 0
     for offset in range(0, len(data), 8): checksum ^= struct.unpack_from('Q', data, offset)[0]
     return hi32(checksum) ^ lo32(checksum)

-  def send_rpc(self, func, msg, wait=False):
+  def send_rpc(self, func:int, msg:bytes, wait=False):
     header = nv.rpc_message_header_v(signature=nv.NV_VGPU_MSG_SIGNATURE_VALID, rpc_result=nv.NV_VGPU_MSG_RESULT_RPC_PENDING,
       rpc_result_private=nv.NV_VGPU_MSG_RESULT_RPC_PENDING, header_version=(3<<24), function=func, length=len(msg) + 0x20)
@@ -49,7 +49,7 @@ class NVRpcQueue:
     self.seq += 1
     self.gsp.nvdev.NV_PGSP_QUEUE_HEAD[0].write(0x0)

-  def wait_resp(self, cmd) -> memoryview:
+  def wait_resp(self, cmd:int) -> memoryview:
     while True:
       System.memory_barrier()
       if self.rx.readPtr == self.tx.writePtr: continue
@@ -177,7 +177,7 @@ class NV_FLCN(NV_IP):
     self.nvdev.NV_PFALCON_FALCON_OS.with_base(self.falcon).write(0x0)
     assert self.nvdev.NV_PRISCV_RISCV_CPUCTL.with_base(self.falcon).read_bitfields()['active_stat'] == 1, "GSP Core is not active"

-  def execute_dma(self, base, cmd, dest, mem_off, sysmem, size):
+  def execute_dma(self, base:int, cmd:int, dest:int, mem_off:int, sysmem:int, size:int):
     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['full'], value=0, msg="DMA does not progress")

     self.nvdev.NV_PFALCON_FALCON_DMATRFBASE.with_base(base).write(lo32(sysmem >> 8))
@@ -194,7 +194,7 @@ class NV_FLCN(NV_IP):

     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['idle'], msg="DMA does not complete")

-  def start_cpu(self, base):
+  def start_cpu(self, base:int):
     if self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).read_bitfields()['alias_en'] == 1:
       self.nvdev.wreg(base + self.nvdev.NV_PFALCON_FALCON_CPUCTL_ALIAS, 0x2)
     else: self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).write(startcpu=1)
@@ -232,11 +232,11 @@ class NV_FLCN(NV_IP):
     if mailbox is not None:
       return self.nvdev.NV_PFALCON_FALCON_MAILBOX0.with_base(base).read(), self.nvdev.NV_PFALCON_FALCON_MAILBOX1.with_base(base).read()

-  def disable_ctx_req(self, base):
+  def disable_ctx_req(self, base:int):
     self.nvdev.NV_PFALCON_FBIF_CTL.with_base(base).update(allow_phys_no_ctx=1)
     self.nvdev.NV_PFALCON_FALCON_DMACTL.with_base(base).write(0x0)

-  def reset(self, base, riscv=False):
+  def reset(self, base:int, riscv=False):
     engine_reg = self.nvdev.NV_PGSP_FALCON_ENGINE if base == self.falcon else self.nvdev.NV_PSEC_FALCON_ENGINE
     engine_reg.write(reset=1)
     time.sleep(0.1)
@@ -408,10 +408,10 @@ class NV_GSP(NV_IP):
     assert self.nvdev.flcn.frts_offset == m.frtsOffset, f"FRTS mismatch: {self.nvdev.flcn.frts_offset} != {m.frtsOffset}"
     self.wpr_meta, self.wpr_meta_sysmem = self.nvdev._alloc_boot_struct(m)

-  def promote_ctx(self, client, subdevice, obj, ctxbufs, bufs=None, virt=None, phys=None):
+  def promote_ctx(self, client:int, subdevice:int, obj:int, ctxbufs:dict[int, GRBufDesc], bufs=None, virt=None, phys=None):
     res, prom = {}, nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_PARAMS(entryCount=len(ctxbufs), engineType=0x1, hChanClient=client, hObject=obj)
     for i,(buf,desc) in enumerate(ctxbufs.items()):
-      use_v, use_p = (desc.v if virt is None else virt), (desc.p if phys is None else phys)
+      use_v, use_p = (desc.virt if virt is None else virt), (desc.phys if phys is None else phys)
       x = (bufs or {}).get(buf, self.nvdev.mm.valloc(desc.size, contiguous=True)) # allocate buffers
       prom.promoteEntry[i] = nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_BUFFER_ENTRY(bufferId=buf, gpuVirtAddr=x.va_addr if use_v else 0, bInitialize=use_p,
         gpuPhysAddr=x.paddrs[0][0] if use_p else 0, size=desc.size if use_p else 0, physAttr=0x4 if use_p else 0, bNonmapped=(use_p and not use_v))
@@ -449,10 +449,11 @@ class NV_GSP(NV_IP):
     gr_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS, add=0x40000)
     patch_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS_PATCH)
     cfgs_sizes = {x: _ctx_info(x + 14, align=(2 << 20) if x == 5 else None) for x in range(3, 11)} # indices 3–10 are mapped to 17–24
-    self.grctx_bufs = {0: GRBufDesc(gr_size, p=1, v=1), 1: GRBufDesc(patch_size, p=1, v=1, lc=1), 2: GRBufDesc(patch_size, p=1, v=1),
-      **{x: GRBufDesc(cfgs_sizes[x], p=0, v=1) for x in range(3, 7)}, 9: GRBufDesc(cfgs_sizes[9], p=1, v=1),
-      10: GRBufDesc(cfgs_sizes[10], p=1, v=0), 11: GRBufDesc(cfgs_sizes[10], p=1, v=1)} # NOTE: 11 reuses cfgs_sizes[10]
-    self.promote_ctx(self.priv_root, subdev, ch_gpfifo, {k:v for k, v in self.grctx_bufs.items() if v.lc == 0})
+    self.grctx_bufs = {0: GRBufDesc(gr_size, phys=True, virt=True), 1: GRBufDesc(patch_size, phys=True, virt=True, local=True),
+      2: GRBufDesc(patch_size, phys=True, virt=True), **{x: GRBufDesc(cfgs_sizes[x], phys=False, virt=True) for x in range(3, 7)},
+      9: GRBufDesc(cfgs_sizes[9], phys=True, virt=True), 10: GRBufDesc(cfgs_sizes[10], phys=True, virt=False),
+      11: GRBufDesc(cfgs_sizes[10], phys=True, virt=True)} # NOTE: 11 reuses cfgs_sizes[10]
+    self.promote_ctx(self.priv_root, subdev, ch_gpfifo, {k:v for k, v in self.grctx_bufs.items() if not v.local})

     self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.compute_class, params=None)
     self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.dma_class, params=None)
@@ -473,7 +474,7 @@ class NV_GSP(NV_IP):

   ### RPCs

-  def rpc_rm_alloc(self, hParent, hClass, params, client=None) -> int:
+  def rpc_rm_alloc(self, hParent:int, hClass:int, params:Any, client=None) -> int:
     if hClass == self.gpfifo_class:
       ramfc_alloc = self.nvdev.mm.valloc(0x1000, contiguous=True)
       params.ramfcMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=ramfc_alloc.paddrs[0][0], size=0x200, addressSpace=2, cacheAttrib=0)
@@ -499,7 +500,7 @@ class NV_GSP(NV_IP):
       self.promote_ctx(client, self.subdevice, hParent, {k:v for k,v in self.grctx_bufs.items() if k in [0, 1, 2]}, phys_gr_ctx, phys=False)
     return obj if hClass != nv_gpu.NV1_ROOT else client

-  def rpc_rm_control(self, hObject, cmd, params, client=None):
+  def rpc_rm_control(self, hObject:int, cmd:int, params:Any, client=None):
     control_args = nv.rpc_gsp_rm_control_v(hClient=(client:=client or self.priv_root), hObject=hObject, cmd=cmd, flags=0x0,
       paramsSize=ctypes.sizeof(params) if params is not None else 0x0)
     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_CONTROL, bytes(control_args) + (bytes(params) if params is not None else b''))
@@ -511,7 +512,7 @@ class NV_GSP(NV_IP):
       cast(nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS, st).workSubmitToken |= (1 << 30)
     return st

-  def rpc_set_page_directory(self, device, hVASpace, pdir_paddr, client=None, pasid=0xffffffff):
+  def rpc_set_page_directory(self, device:int, hVASpace:int, pdir_paddr:int, client=None, pasid=0xffffffff):
     params = nv.struct_NV0080_CTRL_DMA_SET_PAGE_DIRECTORY_PARAMS_v1E_05(physAddress=pdir_paddr, numEntries=self.nvdev.mm.pte_cnt[0], flags=0x8,
       hVASpace=hVASpace, pasid=pasid, subDeviceId=1, chId=0) # flags field is all channels.
     alloc_args = nv.rpc_set_page_directory_v(hClient=client or self.priv_root, hDevice=device, pasid=pasid, params=params)
@@ -544,7 +545,7 @@ class NV_GSP(NV_IP):
     header = nv.PACKED_REGISTRY_TABLE(size=hdr_size + len(entries_bytes) + len(data_bytes), numEntries=len(table))
     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_SET_REGISTRY, bytes(header) + entries_bytes + data_bytes)

-  def run_cpu_seq(self, seq_buf):
+  def run_cpu_seq(self, seq_buf:memoryview):
     hdr = nv.rpc_run_cpu_sequencer_v17_00.from_address(mv_address(seq_buf))
     cmd_iter = iter(seq_buf[ctypes.sizeof(nv.rpc_run_cpu_sequencer_v17_00):].cast('I')[:hdr.cmdIndex])

diff --git a/tinygrad/runtime/support/nv/nvdev.py b/tinygrad/runtime/support/nv/nvdev.py
index d5cd11382a..ce25760964 100644
--- a/tinygrad/runtime/support/nv/nvdev.py
+++ b/tinygrad/runtime/support/nv/nvdev.py
@@ -71,7 +71,7 @@ class NVMemoryManager(MemoryManager):
   def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))

 class NVDev(PCIDevImplBase):
-  def __init__(self, devfmt, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
+  def __init__(self, devfmt:str, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
     self.devfmt, self.mmio, self.vram, self.venid, self.subvenid, self.rev, self.bars = devfmt, mmio, vram, venid, subvenid, rev, bars

     self.lock_fd = System.flock_acquire(f"nv_{self.devfmt}.lock")
@@ -101,10 +101,10 @@ class NVDev(PCIDevImplBase):
     for ip in [self.gsp, self.flcn]: ip.fini_hw()

   def reg(self, reg:str) -> NVReg: return self.__dict__[reg]
-  def wreg(self, addr, value):
+  def wreg(self, addr:int, value:int):
     self.mmio[addr // 4] = value
     if NV_DEBUG >= 4: print(f"wreg: {hex(addr)} = {hex(value)}")
-  def rreg(self, addr): return self.mmio[addr // 4]
+  def rreg(self, addr:int) -> int: return self.mmio[addr // 4]

   def _early_init(self):
     self.reg_names:set[str] = set()
@@ -134,12 +134,12 @@ class NVDev(PCIDevImplBase):

     self.vram_size = self.reg("NV_PGC6_AON_SECURE_SCRATCH_GROUP_42").read() << 20

-  def _alloc_boot_struct(self, struct):
+  def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[ctypes.Structure, int]:
     va, paddrs = System.alloc_sysmem(sz:=ctypes.sizeof(type(struct)), contiguous=True)
     to_mv(va, sz)[:] = bytes(struct)
     return type(struct).from_address(va), paddrs[0]

-  def _download(self, file) -> str:
+  def _download(self, file:str) -> str:
     url = f"https://raw.githubusercontent.com/NVIDIA/open-gpu-kernel-modules/8ec351aeb96a93a4bb69ccc12a542bf8a8df2b6f/{file}"
     return fetch(url, subdir="defines").read_text()
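
For reviewers skimming the GRBufDesc rename (v/p/lc to virt/phys/local), here is a minimal Python sketch of how the new field names read at a call site. It is not part of the patch: the buffer sizes are made-up placeholders, and only the field names and the `not v.local` filter are taken from the diff above.

# illustration only, not part of the patch; sizes are placeholders
import dataclasses

@dataclasses.dataclass(frozen=True)
class GRBufDesc: size:int; virt:bool; phys:bool; local:bool=False # noqa: E702

grctx_bufs = {0: GRBufDesc(0x40000, phys=True, virt=True),            # mapped both physically and virtually
              1: GRBufDesc(0x200, phys=True, virt=True, local=True),  # marked local, skipped by the initial promote
              3: GRBufDesc(0x1000, phys=False, virt=True)}            # virtual-only buffer
# same filter shape as the promote_ctx caller in the diff: only promote buffers not marked local
promoted = {k: v for k, v in grctx_bufs.items() if not v.local}
assert sorted(promoted) == [0, 3]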