diff --git a/tinygrad/runtime/support/nv/ip.py b/tinygrad/runtime/support/nv/ip.py index fc763ab985..bae8e0cbe9 100644 --- a/tinygrad/runtime/support/nv/ip.py +++ b/tinygrad/runtime/support/nv/ip.py @@ -14,6 +14,7 @@ class NV_IP: def __init__(self, nvdev): self.nvdev = nvdev def init_sw(self): pass # Prepare sw/allocations for this IP def init_hw(self): pass # Initialize hw for this IP + def fini_hw(self): pass # Finalize hw for this IP class NVRpcQueue: def __init__(self, gsp:NV_GSP, va:int, completion_q_va:int|None=None): @@ -468,6 +469,8 @@ class NV_GSP(NV_IP): self.priv_root = 0xc1e00004 self.init_golden_image() + def fini_hw(self): self.rpc_unloading_guest_driver() + ### RPCs def rpc_rm_alloc(self, hParent, hClass, params, client=None) -> int: @@ -523,6 +526,11 @@ class NV_GSP(NV_IP): bIsPassthru=1, PCIDeviceID=self.nvdev.venid, PCISubDeviceID=self.nvdev.subvenid, PCIRevisionID=self.nvdev.rev, maxUserVa=0x7ffffffff000) self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_SET_SYSTEM_INFO, bytes(data)) + def rpc_unloading_guest_driver(self): + data = nv.rpc_unloading_guest_driver_v(bInPMTransition=0, bGc6Entering=0, newLevel=(__GPU_STATE_FLAGS_FAST_UNLOAD:=1 << 6)) + self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_UNLOADING_GUEST_DRIVER, bytes(data)) + self.stat_q.wait_resp(nv.NV_VGPU_MSG_FUNCTION_UNLOADING_GUEST_DRIVER) + def rpc_set_registry_table(self): table = {'RMForcePcieConfigSave': 0x1, 'RMSecBusResetEnable': 0x1} entries_bytes, data_bytes = bytes(), bytes() diff --git a/tinygrad/runtime/support/nv/nvdev.py b/tinygrad/runtime/support/nv/nvdev.py index 47a2021e07..d5cd11382a 100644 --- a/tinygrad/runtime/support/nv/nvdev.py +++ b/tinygrad/runtime/support/nv/nvdev.py @@ -97,7 +97,8 @@ class NVDev(PCIDevImplBase): for ip in [self.flcn, self.gsp]: ip.init_sw() for ip in [self.flcn, self.gsp]: ip.init_hw() - def fini(self): System.pci_reset(self.devfmt) # Reset the device to clean up resources. TODO: Consider a warm start process. + def fini(self): + for ip in [self.gsp, self.flcn]: ip.fini_hw() def reg(self, reg:str) -> NVReg: return self.__dict__[reg] def wreg(self, addr, value):