From 337328e40992520fe2dee9943eebb62d255bbf72 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 10 Jan 2025 21:02:34 +0300 Subject: [PATCH] am: fini gpu after use (#8556) * am: fini gpu after use * mypy --- tinygrad/runtime/ops_amd.py | 9 ++++++++- tinygrad/runtime/support/am/amdev.py | 3 +++ tinygrad/runtime/support/am/ip.py | 22 ++++++++++++++++++---- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 8b8cc6444e..b81801cc48 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Any, cast -import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct +import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct, atexit assert sys.platform != 'win32' from dataclasses import dataclass from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, HWInterface @@ -524,6 +524,8 @@ class PCIIface: for d in self.dev.devices: d.dev_iface.adev.gmc.on_interrupt() raise RuntimeError("Device hang detected") + def device_fini(self): self.adev.fini() + class AMDDevice(HCQCompiled): driverless:bool = not HWInterface.exists('/sys/module/amdgpu') or bool(getenv("AMD_DRIVERLESS", 0)) signals_page:Any = None @@ -570,6 +572,7 @@ class AMDDevice(HCQCompiled): super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self), AMDSignal, AMDComputeQueue, AMDCopyQueue) + atexit.register(self.device_fini) def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0): ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True) @@ -584,3 +587,7 @@ class AMDDevice(HCQCompiled): self.synchronize() def on_device_hang(self): self.dev_iface.on_device_hang() + + def device_fini(self): + self.synchronize() + if hasattr(self.dev_iface, 'device_fini'): self.dev_iface.device_fini() diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index ca0d8abd7b..50278e8bea 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -275,6 +275,9 @@ class AMDev: for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init() self.gfx.set_clockgating_state() + def fini(self): + for ip in [self.sdma, self.gfx]: ip.fini() + def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg] def reg(self, reg:str) -> AMRegister: return self.__dict__[reg] diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index b9fd912a3e..1b7eebbd9b 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -7,6 +7,7 @@ from tinygrad.helpers import to_mv, data64, lo32, hi32 class AM_IP: def __init__(self, adev): self.adev = adev def init(self): raise NotImplementedError("IP block init must be implemeted") + def fini(self): pass class AM_SOC21(AM_IP): def init(self): @@ -158,6 +159,15 @@ class AM_GFX(AM_IP): # NOTE: Wait for MEC to be ready. The kernel does udelay here as well. time.sleep(0.05) + def fini(self): + self._grbm_select(me=1, pipe=0, queue=0) + self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES + self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1) + self._grbm_select() + self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=1, mec_pipe0_reset=1, mec_pipe1_reset=1, mec_pipe2_reset=1, mec_pipe3_reset=1, + mec_pipe0_active=0, mec_pipe1_active=0, mec_pipe2_active=0, mec_pipe3_active=0, mec_halt=1) + self.adev.regGCVM_CONTEXT0_CNTL.write(0) + def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int): mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True) @@ -260,10 +270,6 @@ class AM_IH(AM_IP): class AM_SDMA(AM_IP): def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int): - # Stop if something is running... - self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").update(rb_enable=0) - while not self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_CONTEXT_STATUS").read(idle=1): pass - # Setup the ring self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR", "", "_HI", 0) self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR", "", "_HI", 0) @@ -284,6 +290,14 @@ class AM_SDMA(AM_IP): self.adev.regSDMA0_F32_CNTL.update(halt=0, th1_reset=0) self.adev.regSDMA0_CNTL.update(ctxempty_int_enable=1, trap_enable=1) + def fini(self): + self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0) + self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0) + self.adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1) + self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1) + time.sleep(0.01) + self.adev.regGRBM_SOFT_RESET.write(0x0) + class AM_PSP(AM_IP): def __init__(self, adev): super().__init__(adev)