more typing in drivers (#11454)

* more typing in drivers

* rm
Author: nimlgen
Date: 2025-07-31 23:26:33 +03:00
Committer: GitHub
Parent: bad3cf5731
Commit: e5b6149dfb
3 changed files with 33 additions and 32 deletions

View File

@@ -169,12 +169,12 @@ class AM_SMU(AM_IP):
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]))
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
-def _smu_cmn_send_msg(self, msg, param=0, debug=False):
+def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):
(self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).write(0) # resp reg
(self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).write(param)
(self.adev.mmMP1_SMN_C2PMSG_66 if not debug else self.adev.mmMP1_SMN_C2PMSG_75).write(msg)
-def _send_msg(self, msg, param, read_back_arg=False, timeout=10000, debug=False): # 10s
+def _send_msg(self, msg:int, param:int, read_back_arg=False, timeout=10000, debug=False): # default timeout is 10 seconds
self._smu_cmn_send_msg(msg, param, debug=debug)
wait_cond(lambda: (self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).read(), value=1, timeout_ms=timeout,
msg=f"SMU msg {msg:#x} timeout")
@@ -414,12 +414,12 @@ class AM_PSP(AM_IP):
def _wait_for_bootloader(self): wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_35").read() & 0x80000000, value=0x80000000, msg="BL not ready")
-def _prep_msg1(self, data):
+def _prep_msg1(self, data:memoryview):
assert len(data) <= self.msg1_view.nbytes, f"msg1 buffer is too small {len(data):#x} > {self.msg1_view.nbytes:#x}"
self.msg1_view[:len(data)+4] = bytes(data) + b'\x00' * 4
self.adev.gmc.flush_hdp()
-def _bootloader_load_component(self, fw, compid):
+def _bootloader_load_component(self, fw:int, compid:int):
if fw not in self.adev.fw.sos_fw: return 0
self._wait_for_bootloader()
@@ -458,7 +458,7 @@ class AM_PSP(AM_IP):
wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_64").read() & 0x8000FFFF, value=0x80000000, msg="sOS ring not created")
-def _ring_submit(self, cmd):
+def _ring_submit(self, cmd:am.struct_psp_gfx_cmd_resp) -> am.struct_psp_gfx_cmd_resp:
msg = am.struct_psp_gfx_rb_frame(fence_value=(prev_wptr:=self.adev.reg(f"{self.reg_pref}_67").read()),
cmd_buf_addr_lo=lo32(self.adev.paddr2mc(self.cmd_paddr)), cmd_buf_addr_hi=hi32(self.adev.paddr2mc(self.cmd_paddr)),
fence_addr_lo=lo32(self.adev.paddr2mc(self.fence_paddr)), fence_addr_hi=hi32(self.adev.paddr2mc(self.fence_paddr)))
@@ -477,7 +477,7 @@ class AM_PSP(AM_IP):
return resp
-def _load_ip_fw_cmd(self, fw_types, fw_bytes):
+def _load_ip_fw_cmd(self, fw_types:list[int], fw_bytes:memoryview):
self._prep_msg1(fw_bytes)
for fw_type in fw_types:
if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[fw_type]}")
@@ -487,7 +487,7 @@ class AM_PSP(AM_IP):
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
self._ring_submit(cmd)
-def _tmr_load_cmd(self):
+def _tmr_load_cmd(self) -> am.struct_psp_gfx_cmd_resp:
cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_SETUP_TMR)
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
@@ -495,7 +495,7 @@ class AM_PSP(AM_IP):
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
return self._ring_submit(cmd)
-def _load_toc_cmd(self, toc_size):
+def _load_toc_cmd(self, toc_size:int) -> am.struct_psp_gfx_cmd_resp:
cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_LOAD_TOC)
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_addr)
cmd.cmd.cmd_load_toc.toc_size = toc_size
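The _prep_msg1 change above only adds the memoryview annotation, but the staging pattern is worth spelling out: the firmware payload is copied into a fixed buffer and zero-padded by four bytes before the PSP reads it. A simplified sketch, assuming a plain bytearray in place of the real msg1 region (the real driver also flushes the HDP cache afterwards):

```python
msg1_view = memoryview(bytearray(0x100000))  # hypothetical stand-in for the PSP msg1 region

def prep_msg1(data: memoryview) -> None:
    assert len(data) + 4 <= msg1_view.nbytes, f"msg1 buffer is too small {len(data):#x} > {msg1_view.nbytes:#x}"
    msg1_view[:len(data)+4] = bytes(data) + b'\x00' * 4  # payload plus four zero bytes

prep_msg1(memoryview(b"\xde\xad\xbe\xef" * 16))
```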

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import ctypes, time, array, struct, itertools, dataclasses
-from typing import cast
+from typing import cast, Any
from tinygrad.runtime.autogen.nv import nv
from tinygrad.helpers import to_mv, lo32, hi32, DEBUG, round_up, round_down, mv_address, fetch, wait_cond
from tinygrad.runtime.support.system import System
@@ -8,7 +8,7 @@ from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.autogen import nv_gpu
@dataclasses.dataclass(frozen=True)
-class GRBufDesc: size:int; v:int; p:int; lc:int=0 # noqa: E702
+class GRBufDesc: size:int; virt:bool; phys:bool; local:bool=False # noqa: E702
class NV_IP:
def __init__(self, nvdev): self.nvdev = nvdev
@@ -26,13 +26,13 @@ class NVRpcQueue:
self.gsp, self.va, self.queue_va, self.seq = gsp, va, va + self.tx.entryOff, 0
self.queue_mv = to_mv(self.queue_va, self.tx.msgSize * self.tx.msgCount)
-def _checksum(self, data):
+def _checksum(self, data:bytes):
if (pad_len:=(-len(data)) % 8): data += b'\x00' * pad_len
checksum = 0
for offset in range(0, len(data), 8): checksum ^= struct.unpack_from('Q', data, offset)[0]
return hi32(checksum) ^ lo32(checksum)
-def send_rpc(self, func, msg, wait=False):
+def send_rpc(self, func:int, msg:bytes, wait=False):
header = nv.rpc_message_header_v(signature=nv.NV_VGPU_MSG_SIGNATURE_VALID, rpc_result=nv.NV_VGPU_MSG_RESULT_RPC_PENDING,
rpc_result_private=nv.NV_VGPU_MSG_RESULT_RPC_PENDING, header_version=(3<<24), function=func, length=len(msg) + 0x20)
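The _checksum helper shown above (now typed as bytes) pads the payload to a multiple of 8, XORs it as 64-bit words, and folds the result to 32 bits. A standalone restatement, with lo32/hi32 re-declared here since the driver imports them from tinygrad.helpers:

```python
import struct

def lo32(x:int) -> int: return x & 0xFFFFFFFF
def hi32(x:int) -> int: return (x >> 32) & 0xFFFFFFFF

def rpc_checksum(data: bytes) -> int:
    if (pad_len := (-len(data)) % 8): data += b'\x00' * pad_len  # pad to a multiple of 8 bytes
    checksum = 0
    for offset in range(0, len(data), 8):
        checksum ^= struct.unpack_from('Q', data, offset)[0]     # XOR the buffer as u64 words
    return hi32(checksum) ^ lo32(checksum)                       # fold to 32 bits

assert rpc_checksum(b'\x01' + b'\x00' * 7) == 1
```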
@@ -49,7 +49,7 @@ class NVRpcQueue:
self.seq += 1
self.gsp.nvdev.NV_PGSP_QUEUE_HEAD[0].write(0x0)
-def wait_resp(self, cmd) -> memoryview:
+def wait_resp(self, cmd:int) -> memoryview:
while True:
System.memory_barrier()
if self.rx.readPtr == self.tx.writePtr: continue
@@ -177,7 +177,7 @@ class NV_FLCN(NV_IP):
self.nvdev.NV_PFALCON_FALCON_OS.with_base(self.falcon).write(0x0)
assert self.nvdev.NV_PRISCV_RISCV_CPUCTL.with_base(self.falcon).read_bitfields()['active_stat'] == 1, "GSP Core is not active"
-def execute_dma(self, base, cmd, dest, mem_off, sysmem, size):
+def execute_dma(self, base:int, cmd:int, dest:int, mem_off:int, sysmem:int, size:int):
wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['full'], value=0, msg="DMA does not progress")
self.nvdev.NV_PFALCON_FALCON_DMATRFBASE.with_base(base).write(lo32(sysmem >> 8))
@@ -194,7 +194,7 @@ class NV_FLCN(NV_IP):
wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['idle'], msg="DMA does not complete")
-def start_cpu(self, base):
+def start_cpu(self, base:int):
if self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).read_bitfields()['alias_en'] == 1:
self.nvdev.wreg(base + self.nvdev.NV_PFALCON_FALCON_CPUCTL_ALIAS, 0x2)
else: self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).write(startcpu=1)
@@ -232,11 +232,11 @@ class NV_FLCN(NV_IP):
if mailbox is not None:
return self.nvdev.NV_PFALCON_FALCON_MAILBOX0.with_base(base).read(), self.nvdev.NV_PFALCON_FALCON_MAILBOX1.with_base(base).read()
-def disable_ctx_req(self, base):
+def disable_ctx_req(self, base:int):
self.nvdev.NV_PFALCON_FBIF_CTL.with_base(base).update(allow_phys_no_ctx=1)
self.nvdev.NV_PFALCON_FALCON_DMACTL.with_base(base).write(0x0)
-def reset(self, base, riscv=False):
+def reset(self, base:int, riscv=False):
engine_reg = self.nvdev.NV_PGSP_FALCON_ENGINE if base == self.falcon else self.nvdev.NV_PSEC_FALCON_ENGINE
engine_reg.write(reset=1)
time.sleep(0.1)
@@ -408,10 +408,10 @@ class NV_GSP(NV_IP):
assert self.nvdev.flcn.frts_offset == m.frtsOffset, f"FRTS mismatch: {self.nvdev.flcn.frts_offset} != {m.frtsOffset}"
self.wpr_meta, self.wpr_meta_sysmem = self.nvdev._alloc_boot_struct(m)
-def promote_ctx(self, client, subdevice, obj, ctxbufs, bufs=None, virt=None, phys=None):
+def promote_ctx(self, client:int, subdevice:int, obj:int, ctxbufs:dict[int, GRBufDesc], bufs=None, virt=None, phys=None):
res, prom = {}, nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_PARAMS(entryCount=len(ctxbufs), engineType=0x1, hChanClient=client, hObject=obj)
for i,(buf,desc) in enumerate(ctxbufs.items()):
-use_v, use_p = (desc.v if virt is None else virt), (desc.p if phys is None else phys)
+use_v, use_p = (desc.virt if virt is None else virt), (desc.phys if phys is None else phys)
x = (bufs or {}).get(buf, self.nvdev.mm.valloc(desc.size, contiguous=True)) # allocate buffers
prom.promoteEntry[i] = nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_BUFFER_ENTRY(bufferId=buf, gpuVirtAddr=x.va_addr if use_v else 0, bInitialize=use_p,
gpuPhysAddr=x.paddrs[0][0] if use_p else 0, size=desc.size if use_p else 0, physAttr=0x4 if use_p else 0, bNonmapped=(use_p and not use_v))
@@ -449,10 +449,11 @@ class NV_GSP(NV_IP):
gr_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS, add=0x40000)
patch_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS_PATCH)
cfgs_sizes = {x: _ctx_info(x + 14, align=(2 << 20) if x == 5 else None) for x in range(3, 11)} # indices 3-10 are mapped to 17-24
-self.grctx_bufs = {0: GRBufDesc(gr_size, p=1, v=1), 1: GRBufDesc(patch_size, p=1, v=1, lc=1), 2: GRBufDesc(patch_size, p=1, v=1),
-**{x: GRBufDesc(cfgs_sizes[x], p=0, v=1) for x in range(3, 7)}, 9: GRBufDesc(cfgs_sizes[9], p=1, v=1),
-10: GRBufDesc(cfgs_sizes[10], p=1, v=0), 11: GRBufDesc(cfgs_sizes[10], p=1, v=1)} # NOTE: 11 reuses cfgs_sizes[10]
-self.promote_ctx(self.priv_root, subdev, ch_gpfifo, {k:v for k, v in self.grctx_bufs.items() if v.lc == 0})
+self.grctx_bufs = {0: GRBufDesc(gr_size, phys=True, virt=True), 1: GRBufDesc(patch_size, phys=True, virt=True, local=True),
+2: GRBufDesc(patch_size, phys=True, virt=True), **{x: GRBufDesc(cfgs_sizes[x], phys=False, virt=True) for x in range(3, 7)},
+9: GRBufDesc(cfgs_sizes[9], phys=True, virt=True), 10: GRBufDesc(cfgs_sizes[10], phys=True, virt=False),
+11: GRBufDesc(cfgs_sizes[10], phys=True, virt=True)} # NOTE: 11 reuses cfgs_sizes[10]
+self.promote_ctx(self.priv_root, subdev, ch_gpfifo, {k:v for k, v in self.grctx_bufs.items() if not v.local})
self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.compute_class, params=None)
self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.dma_class, params=None)
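Besides the annotations, this commit renames GRBufDesc's single-letter v/p/lc flags to virt/phys/local (see the dataclass hunk above), so filters like the promote_ctx call here read naturally. A self-contained sketch with made-up buffer sizes:

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class GRBufDesc:
    size: int
    virt: bool
    phys: bool
    local: bool = False

# illustrative sizes only; the real values come from the GET_ENGINE_CONTEXT_PROPERTIES queries
grctx_bufs = {0: GRBufDesc(0x40000, phys=True, virt=True), 1: GRBufDesc(0x5000, phys=True, virt=True, local=True)}
non_local = {k: v for k, v in grctx_bufs.items() if not v.local}  # replaces the old `v.lc == 0` filter
assert list(non_local) == [0]
```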
@@ -473,7 +474,7 @@ class NV_GSP(NV_IP):
### RPCs
-def rpc_rm_alloc(self, hParent, hClass, params, client=None) -> int:
+def rpc_rm_alloc(self, hParent:int, hClass:int, params:Any, client=None) -> int:
if hClass == self.gpfifo_class:
ramfc_alloc = self.nvdev.mm.valloc(0x1000, contiguous=True)
params.ramfcMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=ramfc_alloc.paddrs[0][0], size=0x200, addressSpace=2, cacheAttrib=0)
@@ -499,7 +500,7 @@ class NV_GSP(NV_IP):
self.promote_ctx(client, self.subdevice, hParent, {k:v for k,v in self.grctx_bufs.items() if k in [0, 1, 2]}, phys_gr_ctx, phys=False)
return obj if hClass != nv_gpu.NV1_ROOT else client
-def rpc_rm_control(self, hObject, cmd, params, client=None):
+def rpc_rm_control(self, hObject:int, cmd:int, params:Any, client=None):
control_args = nv.rpc_gsp_rm_control_v(hClient=(client:=client or self.priv_root), hObject=hObject, cmd=cmd, flags=0x0,
paramsSize=ctypes.sizeof(params) if params is not None else 0x0)
self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_CONTROL, bytes(control_args) + (bytes(params) if params is not None else b''))
@@ -511,7 +512,7 @@ class NV_GSP(NV_IP):
cast(nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS, st).workSubmitToken |= (1 << 30)
return st
-def rpc_set_page_directory(self, device, hVASpace, pdir_paddr, client=None, pasid=0xffffffff):
+def rpc_set_page_directory(self, device:int, hVASpace:int, pdir_paddr:int, client=None, pasid=0xffffffff):
params = nv.struct_NV0080_CTRL_DMA_SET_PAGE_DIRECTORY_PARAMS_v1E_05(physAddress=pdir_paddr,
numEntries=self.nvdev.mm.pte_cnt[0], flags=0x8, hVASpace=hVASpace, pasid=pasid, subDeviceId=1, chId=0) # flags field is all channels.
alloc_args = nv.rpc_set_page_directory_v(hClient=client or self.priv_root, hDevice=device, pasid=pasid, params=params)
@@ -544,7 +545,7 @@ class NV_GSP(NV_IP):
header = nv.PACKED_REGISTRY_TABLE(size=hdr_size + len(entries_bytes) + len(data_bytes), numEntries=len(table))
self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_SET_REGISTRY, bytes(header) + entries_bytes + data_bytes)
-def run_cpu_seq(self, seq_buf):
+def run_cpu_seq(self, seq_buf:memoryview):
hdr = nv.rpc_run_cpu_sequencer_v17_00.from_address(mv_address(seq_buf))
cmd_iter = iter(seq_buf[ctypes.sizeof(nv.rpc_run_cpu_sequencer_v17_00):].cast('I')[:hdr.cmdIndex])

View File

@@ -71,7 +71,7 @@ class NVMemoryManager(MemoryManager):
def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))
class NVDev(PCIDevImplBase):
-def __init__(self, devfmt, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
+def __init__(self, devfmt:str, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
self.devfmt, self.mmio, self.vram, self.venid, self.subvenid, self.rev, self.bars = devfmt, mmio, vram, venid, subvenid, rev, bars
self.lock_fd = System.flock_acquire(f"nv_{self.devfmt}.lock")
@@ -101,10 +101,10 @@ class NVDev(PCIDevImplBase):
for ip in [self.gsp, self.flcn]: ip.fini_hw()
def reg(self, reg:str) -> NVReg: return self.__dict__[reg]
-def wreg(self, addr, value):
+def wreg(self, addr:int, value:int):
self.mmio[addr // 4] = value
if NV_DEBUG >= 4: print(f"wreg: {hex(addr)} = {hex(value)}")
-def rreg(self, addr): return self.mmio[addr // 4]
+def rreg(self, addr:int) -> int: return self.mmio[addr // 4]
def _early_init(self):
self.reg_names:set[str] = set()
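wreg/rreg above index the BAR mapping in 32-bit words, which is why the byte address is divided by 4. A hedged sketch with a plain array standing in for the real MMIOInterface:

```python
import array

mmio = array.array('I', [0] * 0x400)  # hypothetical 4 KiB register window

def wreg(addr:int, value:int) -> None:
    mmio[addr // 4] = value           # registers are addressed in 32-bit words

def rreg(addr:int) -> int:
    return mmio[addr // 4]

wreg(0x100, 0xDEADBEEF)
assert rreg(0x100) == 0xDEADBEEF
```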
@@ -134,12 +134,12 @@ class NVDev(PCIDevImplBase):
self.vram_size = self.reg("NV_PGC6_AON_SECURE_SCRATCH_GROUP_42").read() << 20
-def _alloc_boot_struct(self, struct):
+def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[ctypes.Structure, int]:
va, paddrs = System.alloc_sysmem(sz:=ctypes.sizeof(type(struct)), contiguous=True)
to_mv(va, sz)[:] = bytes(struct)
return type(struct).from_address(va), paddrs[0]
-def _download(self, file) -> str:
+def _download(self, file:str) -> str:
url = f"https://raw.githubusercontent.com/NVIDIA/open-gpu-kernel-modules/8ec351aeb96a93a4bb69ccc12a542bf8a8df2b6f/{file}"
return fetch(url, subdir="defines").read_text()
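_alloc_boot_struct above copies a ctypes structure into freshly allocated system memory and returns a typed view of that memory together with its physical address. A minimal sketch of the same pattern, assuming a plain ctypes buffer in place of System.alloc_sysmem and a made-up BootStruct payload (so the returned address is virtual, not physical):

```python
import ctypes

class BootStruct(ctypes.Structure):  # hypothetical boot structure
    _fields_ = [("magic", ctypes.c_uint32), ("version", ctypes.c_uint32)]

def alloc_boot_struct(struct: ctypes.Structure) -> tuple[ctypes.Structure, int]:
    backing = ctypes.create_string_buffer(bytes(struct), ctypes.sizeof(type(struct)))
    # return a typed view aliasing the new allocation, plus the allocation's address
    return type(struct).from_buffer(backing), ctypes.addressof(backing)

view, addr = alloc_boot_struct(BootStruct(magic=0xB007, version=1))
assert view.magic == 0xB007 and addr != 0
```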