mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, cast
|
||||
import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct
|
||||
import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select, struct, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, HWInterface
|
||||
@@ -524,6 +524,8 @@ class PCIIface:
|
||||
for d in self.dev.devices: d.dev_iface.adev.gmc.on_interrupt()
|
||||
raise RuntimeError("Device hang detected")
|
||||
|
||||
def device_fini(self):
  """Shut the interface down by delegating to the `fini` method of `self.adev`."""
  adev = self.adev
  adev.fini()
|
||||
|
||||
class AMDDevice(HCQCompiled):
|
||||
driverless:bool = not HWInterface.exists('/sys/module/amdgpu') or bool(getenv("AMD_DRIVERLESS", 0))
|
||||
signals_page:Any = None
|
||||
@@ -570,6 +572,7 @@ class AMDDevice(HCQCompiled):
|
||||
|
||||
super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
|
||||
AMDSignal, AMDComputeQueue, AMDCopyQueue)
|
||||
atexit.register(self.device_fini)
|
||||
|
||||
def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0):
|
||||
ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True)
|
||||
@@ -584,3 +587,7 @@ class AMDDevice(HCQCompiled):
|
||||
self.synchronize()
|
||||
|
||||
def on_device_hang(self):
  """Forward hang handling to the device interface (`self.dev_iface.on_device_hang`)."""
  iface = self.dev_iface
  iface.on_device_hang()
|
||||
|
||||
def device_fini(self):
  """Wait for outstanding work, then run the interface teardown hook.

  `self.synchronize()` is called first. The interface-level `device_fini` hook is optional: it is
  only invoked when `self.dev_iface` defines it. The hook runs from a `finally` clause so that it
  still executes when `synchronize()` raises; the exception from `synchronize()` then propagates
  to the caller unchanged.
  """
  try:
    self.synchronize()
  finally:
    # Not every interface implements a teardown hook, so probe for it before calling.
    if hasattr(self.dev_iface, 'device_fini'): self.dev_iface.device_fini()
|
||||
|
||||
@@ -275,6 +275,9 @@ class AMDev:
|
||||
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init()
|
||||
self.gfx.set_clockgating_state()
|
||||
|
||||
def fini(self):
  """Call `fini()` on the SDMA block and then on the GFX block, in that order."""
  teardown_order = (self.sdma, self.gfx)
  for block in teardown_order:
    block.fini()
|
||||
|
||||
def ip_base(self, ip:str, inst:int, seg:int) -> int:
  """Return the entry of `self.regs_offset` for hardware IP `ip`, instance `inst`, segment `seg`.

  The first index is the value stored under the name `{ip}_HWIP` in the `am` module namespace;
  a `KeyError` is raised if `am` has no such name.
  """
  offsets = self.regs_offset
  hwip = am.__dict__[f"{ip}_HWIP"]
  return offsets[hwip][inst][seg]
|
||||
|
||||
def reg(self, reg:str) -> AMRegister:
  """Return the object stored under the instance attribute named `reg`.

  The lookup goes directly through this instance's `__dict__`, so a missing name raises `KeyError`.
  """
  registers = self.__dict__
  return registers[reg]
|
||||
|
||||
@@ -7,6 +7,7 @@ from tinygrad.helpers import to_mv, data64, lo32, hi32
|
||||
class AM_IP:
  """Base class for IP blocks; stores the owning device object as `self.adev`."""
  def __init__(self, adev): self.adev = adev

  def init(self):
    """Bring the IP block up. The base implementation always raises `NotImplementedError`."""
    raise NotImplementedError("IP block init must be implemented")

  def fini(self):
    """Tear the IP block down. The base implementation does nothing, so overriding is optional."""
|
||||
|
||||
class AM_SOC21(AM_IP):
|
||||
def init(self):
|
||||
@@ -158,6 +159,15 @@ class AM_GFX(AM_IP):
|
||||
# NOTE: Wait for MEC to be ready. The kernel does udelay here as well.
|
||||
time.sleep(0.05)
|
||||
|
||||
def fini(self):
  """Run the GFX teardown register sequence.

  In order: call `_grbm_select(me=1, pipe=0, queue=0)`, write 0x2 to CP_HQD_DEQUEUE_REQUEST and 1 to
  SPI_COMPUTE_QUEUE_RESET, call `_grbm_select()` with no arguments, update CP_MEC_RS64_CNTL (icache
  invalidate set, all four pipe resets set, all four pipe actives cleared, halt set), and finally
  write 0 to GCVM_CONTEXT0_CNTL.
  """
  adev = self.adev

  self._grbm_select(me=1, pipe=0, queue=0)
  adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
  adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
  self._grbm_select()

  # Build the CP_MEC_RS64_CNTL field set in the same keyword order as a hand-written call.
  mec_fields = {"mec_invalidate_icache": 1}
  mec_fields.update({f"mec_pipe{i}_reset": 1 for i in range(4)})
  mec_fields.update({f"mec_pipe{i}_active": 0 for i in range(4)})
  mec_fields["mec_halt"] = 1
  adev.regCP_MEC_RS64_CNTL.update(**mec_fields)

  adev.regGCVM_CONTEXT0_CNTL.write(0)
|
||||
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
|
||||
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
|
||||
|
||||
@@ -260,10 +270,6 @@ class AM_IH(AM_IP):
|
||||
|
||||
class AM_SDMA(AM_IP):
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
|
||||
# Stop if something is running...
|
||||
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").update(rb_enable=0)
|
||||
while not self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_CONTEXT_STATUS").read(idle=1): pass
|
||||
|
||||
# Setup the ring
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR", "", "_HI", 0)
|
||||
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR", "", "_HI", 0)
|
||||
@@ -284,6 +290,14 @@ class AM_SDMA(AM_IP):
|
||||
self.adev.regSDMA0_F32_CNTL.update(halt=0, th1_reset=0)
|
||||
self.adev.regSDMA0_CNTL.update(ctxempty_int_enable=1, trap_enable=1)
|
||||
|
||||
def fini(self):
  """Run the SDMA teardown register sequence.

  In order: clear `rb_enable` in SDMA0_QUEUE0_RB_CNTL, clear `ib_enable` in SDMA0_QUEUE0_IB_CNTL,
  set `halt` and `th1_reset` in SDMA0_F32_CNTL, then write `soft_reset_sdma0=1` to GRBM_SOFT_RESET,
  sleep 10 ms and write 0 to the same register.
  """
  adev = self.adev

  adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
  adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
  adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1)

  # Set the SDMA0 soft-reset field, wait, then write the register back to zero.
  soft_reset = adev.regGRBM_SOFT_RESET
  soft_reset.write(soft_reset_sdma0=1)
  time.sleep(0.01)
  soft_reset.write(0x0)
|
||||
|
||||
class AM_PSP(AM_IP):
|
||||
def __init__(self, adev):
|
||||
super().__init__(adev)
|
||||
|
||||
Reference in New Issue
Block a user