am: fini gpu after use (#8556)

* am: fini gpu after use

* mypy
This commit is contained in:
nimlgen
2025-01-10 21:02:34 +03:00
committed by GitHub
parent 6a7f971fa0
commit 337328e409
3 changed files with 29 additions and 5 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, cast
import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct
import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, HWInterface
@@ -524,6 +524,8 @@ class PCIIface:
for d in self.dev.devices: d.dev_iface.adev.gmc.on_interrupt()
raise RuntimeError("Device hang detected")
def device_fini(self): self.adev.fini()
class AMDDevice(HCQCompiled):
driverless:bool = not HWInterface.exists('/sys/module/amdgpu') or bool(getenv("AMD_DRIVERLESS", 0))
signals_page:Any = None
@@ -570,6 +572,7 @@ class AMDDevice(HCQCompiled):
super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
AMDSignal, AMDComputeQueue, AMDCopyQueue)
atexit.register(self.device_fini)
def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0):
ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True)
@@ -584,3 +587,7 @@ class AMDDevice(HCQCompiled):
self.synchronize()
def on_device_hang(self): self.dev_iface.on_device_hang()
def device_fini(self):
self.synchronize()
if hasattr(self.dev_iface, 'device_fini'): self.dev_iface.device_fini()

View File

@@ -275,6 +275,9 @@ class AMDev:
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init()
self.gfx.set_clockgating_state()
def fini(self):
for ip in [self.sdma, self.gfx]: ip.fini()
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]

View File

@@ -7,6 +7,7 @@ from tinygrad.helpers import to_mv, data64, lo32, hi32
class AM_IP:
def __init__(self, adev): self.adev = adev
def init(self): raise NotImplementedError("IP block init must be implemeted")
def fini(self): pass
class AM_SOC21(AM_IP):
def init(self):
@@ -158,6 +159,15 @@ class AM_GFX(AM_IP):
# NOTE: Wait for MEC to be ready. The kernel does udelay here as well.
time.sleep(0.05)
def fini(self):
self._grbm_select(me=1, pipe=0, queue=0)
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
self._grbm_select()
self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=1, mec_pipe0_reset=1, mec_pipe1_reset=1, mec_pipe2_reset=1, mec_pipe3_reset=1,
mec_pipe0_active=0, mec_pipe1_active=0, mec_pipe2_active=0, mec_pipe3_active=0, mec_halt=1)
self.adev.regGCVM_CONTEXT0_CNTL.write(0)
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
@@ -260,10 +270,6 @@ class AM_IH(AM_IP):
class AM_SDMA(AM_IP):
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
# Stop if something is running...
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").update(rb_enable=0)
while not self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_CONTEXT_STATUS").read(idle=1): pass
# Setup the ring
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR", "", "_HI", 0)
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR", "", "_HI", 0)
@@ -284,6 +290,14 @@ class AM_SDMA(AM_IP):
self.adev.regSDMA0_F32_CNTL.update(halt=0, th1_reset=0)
self.adev.regSDMA0_CNTL.update(ctxempty_int_enable=1, trap_enable=1)
def fini(self):
self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
self.adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1)
self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
time.sleep(0.01)
self.adev.regGRBM_SOFT_RESET.write(0x0)
class AM_PSP(AM_IP):
def __init__(self, adev):
super().__init__(adev)