mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
bind for sdma
This commit is contained in:
@@ -199,24 +199,45 @@ class AMDCopyQueue(HWQueue):
|
||||
*data64_le(signal.timestamp_addr))
|
||||
return self
|
||||
|
||||
def bind(self, dev:AMDDevice):
  """Stage this queue's commands in a device buffer and build an indirect packet that points at it.

  Does nothing unless `dev.dev_iface` is a `VFIOIface`. Otherwise the queue
  contents are copied into a freshly allocated, zero-filled buffer,
  `self.indirect_cmd` is set to a packet referencing that buffer's `va_addr`,
  and `self._q` is replaced by the 32-bit view over the buffer.
  """
  if not isinstance(dev.dev_iface, VFIOIface):
    return

  self.binded_device = dev

  # Entry count is rounded up to a multiple of 8; each entry occupies 4 bytes.
  qsz = round_up(len(self._q), 8)
  self.hw_page = dev.allocator.alloc(qsz * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True))
  page = self.hw_page

  # Zero the whole allocation first so the entries past the end of the queue read as 0.
  ctypes.memset(page.cpu_addr, 0x0, page.size)
  hw_view = to_mv(page.cpu_addr, page.size).cast("I")
  for idx, dword in enumerate(self._q):
    hw_view[idx] = dword

  # Packet layout: header, 64-bit buffer address, entry count, then a zero 64-bit field.
  # NOTE(review): the meaning of the trailing zero field is not visible here - confirm against the SDMA packet spec.
  header = amd_gpu.SDMA_OP_INDIRECT | amd_gpu.SDMA_PKT_INDIRECT_HEADER_VMID(0)
  self.indirect_cmd = [header, *data64_le(page.va_addr), qsz, *data64_le(0)]

  self._q = hw_view
  self.cmd_sizes = [len(self.indirect_cmd)]
|
||||
|
||||
def _submit(self, dev:AMDDevice):
|
||||
if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun")
|
||||
|
||||
if self.binded_device == dev:
|
||||
# An IB packet must end on an 8 DW boundary.
|
||||
add = (8 - (((dev.sdma_queue.put_value % 32) // 4) + len(self.indirect_cmd) % 8)) % 8
|
||||
cmds, cmd_sizes = ([0] * add) + self.indirect_cmd, [len(self.indirect_cmd) + add]
|
||||
|
||||
if len(cmds) * 4 >= (dev.sdma_queue.ring.nbytes - dev.sdma_queue.put_value % dev.sdma_queue.ring.nbytes):
|
||||
cmds, cmd_sizes = [0, 0] + self.indirect_cmd, [8]
|
||||
else: cmds, cmd_sizes = self._q, self.internal_cmd_sizes
|
||||
|
||||
tail_blit_dword = 0
|
||||
for cmdsz in self.internal_cmd_sizes:
|
||||
for cmdsz in cmd_sizes:
|
||||
if (tail_blit_dword + cmdsz) * 4 >= dev.sdma_queue.ring.nbytes - dev.sdma_queue.put_value % dev.sdma_queue.ring.nbytes: break
|
||||
tail_blit_dword += cmdsz
|
||||
|
||||
start_idx = (dev.sdma_queue.put_value % dev.sdma_queue.ring.nbytes) // 4
|
||||
dev.sdma_queue.ring[start_idx : start_idx + tail_blit_dword] = array.array('I', self._q[:tail_blit_dword])
|
||||
dev.sdma_queue.ring[start_idx : start_idx + tail_blit_dword] = array.array('I', cmds[:tail_blit_dword])
|
||||
dev.sdma_queue.put_value += tail_blit_dword * 4
|
||||
|
||||
if (rem_packet_cnt := len(self._q) - tail_blit_dword) > 0:
|
||||
if (rem_packet_cnt := len(cmds) - tail_blit_dword) > 0:
|
||||
zero_fill = dev.sdma_queue.ring.nbytes - dev.sdma_queue.put_value % dev.sdma_queue.ring.nbytes
|
||||
ctypes.memset(mv_address(dev.sdma_queue.ring) + (dev.sdma_queue.put_value % dev.sdma_queue.ring.nbytes), 0, zero_fill)
|
||||
dev.sdma_queue.put_value += zero_fill
|
||||
|
||||
dev.sdma_queue.ring[0:rem_packet_cnt] = array.array('I', self._q[tail_blit_dword:])
|
||||
dev.sdma_queue.ring[0:rem_packet_cnt] = array.array('I', cmds[tail_blit_dword:])
|
||||
dev.sdma_queue.put_value += rem_packet_cnt * 4
|
||||
|
||||
dev.sdma_queue.write_ptr[0] = dev.sdma_queue.put_value
|
||||
@@ -435,7 +456,7 @@ class VFIOIface:
|
||||
# vfio.VFIO_DEVICE_SET_IRQS(self.vfio_dev, irqs)
|
||||
else: libpciaccess.pci_device_enable(ctypes.byref(self.pcidev))
|
||||
|
||||
self.adev = AMDev(self.pcidev, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
|
||||
self.adev = AMDev(self.pcidev, self._map_pci_range(0), dbell:=self._map_pci_range(2), self._map_pci_range(5))
|
||||
self.doorbell_cpu_addr = mv_address(dbell)
|
||||
|
||||
# TODO: think of a way to handle this
|
||||
|
||||
@@ -6,7 +6,6 @@ from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhu
|
||||
from tinygrad.runtime.support.am.mm import MM, GPUPhysicalMemoryBlock
|
||||
from tinygrad.runtime.support.am.firmware import Firmware
|
||||
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
||||
# from tinygrad.runtime.support.am.hal import HAL, PCIHAL, VFIOHAL, read_pagemap
|
||||
|
||||
AM_DEBUG = getenv("AM_DEBUG", 0)
|
||||
|
||||
@@ -35,17 +34,9 @@ class AMRegister:
|
||||
def read(self, **kwargs):
  """Return the register value at `self.reg_off`, ANDed with the first item from `self._parse_kwargs(**kwargs)`.

  NOTE(review): that first item is presumably a bit mask built from the named fields - confirm in `_parse_kwargs`.
  """
  raw_value = self.adev.rreg(self.reg_off)
  mask = self._parse_kwargs(**kwargs)[0]
  return raw_value & mask
|
||||
|
||||
class AMDev:
|
||||
# hal:Optional[HAL] = None
|
||||
|
||||
def __init__(self, pcidev, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
|
||||
# if AMDev.hal is None: AMDev.hal = VFIOHAL()
|
||||
# self.hal_dev = AMDev.hal.open_device(dev_idx)
|
||||
self.pcidev = pcidev
|
||||
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
|
||||
|
||||
# self.vram_cpu_addr, self.vram = AMDev.hal.map_pci_range(self.hal_dev, bar=0, cast='B')
|
||||
# self.doorbell_cpu_addr, self.doorbell64 = AMDev.hal.map_pci_range(self.hal_dev, bar=2, cast='Q')
|
||||
# self.mmio_cpu_addr, self.mmio = AMDev.hal.map_pci_range(self.hal_dev, bar=5, cast='I')
|
||||
self.vram, self.doorbell64, self.mmio = vram_bar.cast('B'), doorbell_bar.cast('Q'), mmio_bar.cast('I')
|
||||
|
||||
self._run_discovery()
|
||||
self._build_regs()
|
||||
@@ -70,7 +61,6 @@ class AMDev:
|
||||
|
||||
self.soc21.init()
|
||||
self.gmc.init()
|
||||
# self.regRLC_SPM_MC_CNTL.write(0xf)
|
||||
self.ih.init()
|
||||
self.psp.init()
|
||||
self.smu.init()
|
||||
|
||||
@@ -77,9 +77,7 @@ class Firmware:
|
||||
self.sos_fw = {}
|
||||
|
||||
blob, sos_hdr = load_fw(self.SOS_PATH, am.struct_psp_firmware_header_v2_0)
|
||||
fw_bin = sos_hdr.psp_fw_bin
|
||||
|
||||
for fw_i in range(sos_hdr.psp_fw_bin_count):
|
||||
fw_bin_desc = am.struct_psp_fw_bin_desc.from_address(ctypes.addressof(fw_bin) + fw_i * ctypes.sizeof(am.struct_psp_fw_bin_desc))
|
||||
fw_bin_desc = am.struct_psp_fw_bin_desc.from_address(ctypes.addressof(sos_hdr.psp_fw_bin) + fw_i * ctypes.sizeof(am.struct_psp_fw_bin_desc))
|
||||
ucode_start_offset = fw_bin_desc.offset_bytes + sos_hdr.header.ucode_array_offset_bytes
|
||||
self.sos_fw[fw_bin_desc.fw_type] = blob[ucode_start_offset:ucode_start_offset+fw_bin_desc.size_bytes]
|
||||
|
||||
Reference in New Issue
Block a user