multigpu works (#4040)

This commit is contained in:
George Hotz
2024-04-02 08:29:37 -07:00
committed by GitHub
parent 05e7f930ee
commit 506b1c5892

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Tuple
from typing import Tuple, Any
import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno
from tinygrad.device import Compiled, LRUAllocator, Compiler, BufferOptions, CompilerOptions
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up
@@ -187,12 +187,18 @@ class KFDAllocator(LRUAllocator):
MAP_FIXED, MAP_NORESERVE = 0x10, 0x400
class KFDDevice(Compiled):
kfd:int = -1
event_page:Any = None # TODO: fix types in kfd, Optional[kfd.struct_kfd_ioctl_alloc_memory_of_gpu_args]
def _map_userptr_to_gpu(self, addr, size):
  """Fill the reusable `map_uptr2gpu_struct` for a user range and submit it via `kio.svm`.

  addr: start address of the range; it is aligned down to a 0x1000 boundary.
  size: length of the range; it is extended by the bytes skipped when aligning
        `addr` down, then passed through `round_up(..., 0x1000)`.
  """
  # clear the low 12 bits so the range begins on a 0x1000 boundary
  aligned_start = addr & ~0xfff
  request = self.map_uptr2gpu_struct
  request.start_addr = aligned_start
  # the leading bytes between aligned_start and addr must be covered as well
  request.size = round_up(size + addr - aligned_start, 0x1000)
  # NOTE(review): presumably the KFD SVM ioctl wrapper -- confirm against the kio definition
  kio.svm(self.kfd, made_struct=request)
def _gpu_map(self, mem):
  """Map the allocation `mem` to this device through `kio.map_memory_to_gpu`.

  mem: allocation object exposing `handle`. A one-element int32 array holding
       this device's `gpu_id` is stored on it as `mapped_gpu_ids`, replacing any
       previous value of that attribute.

  Asserts that the call reports exactly one successful mapping.
  """
  device_ids = (ctypes.c_int32 * 1)(self.gpu_id)
  # attached to mem so the array whose raw address is passed below stays referenced
  setattr(mem, "mapped_gpu_ids", device_ids)
  result = kio.map_memory_to_gpu(
    self.kfd,
    handle=mem.handle,
    device_ids_array_ptr=ctypes.addressof(device_ids),
    n_devices=len(device_ids),
  )
  assert result.n_success == 1
def _gpu_alloc(self, size:int, flags:int, uncached=False, public=False, map_to_gpu=True):
flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
if uncached: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_COHERENT | kfd.KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED
@@ -207,10 +213,7 @@ class KFDDevice(Compiled):
buf = libc.mmap(mem.va_addr, mem.size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|MAP_FIXED, self.drm_fd, mem.mmap_offset)
assert buf != 0xffffffffffffffff
assert addr == buf == mem.va_addr
if map_to_gpu:
mem.__setattr__("mapped_gpu_ids", (ctypes.c_int32 * 1)(self.gpu_id))
stm = kio.map_memory_to_gpu(self.kfd, handle=mem.handle, device_ids_array_ptr=ctypes.addressof(gpus:=mem.mapped_gpu_ids), n_devices=len(gpus))
assert stm.n_success == 1
if map_to_gpu: self._gpu_map(mem)
return mem
def _gpu_free(self, mem):
@@ -229,20 +232,26 @@ class KFDDevice(Compiled):
self.arch = f"gfx{self.properties['gfx_target_version']//100}"
kio.acquire_vm(KFDDevice.kfd, drm_fd=self.drm_fd, gpu_id=self.gpu_id)
self.event_page = self._gpu_alloc(0x8000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
self.sync_event = kio.create_event(KFDDevice.kfd, event_page_offset=self.event_page.handle, auto_reset=1)
self.eop_buffer = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
if KFDDevice.event_page is None:
KFDDevice.event_page = self._gpu_alloc(0x8000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
self.sync_event = kio.create_event(KFDDevice.kfd, event_page_offset=KFDDevice.event_page.handle, auto_reset=1)
else:
self._gpu_map(KFDDevice.event_page)
self.sync_event = kio.create_event(KFDDevice.kfd, auto_reset=1)
self.gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
self.aql_ring = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, uncached=True)
self.signals_page = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, uncached=True)
self.gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
self.kernargs = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.pm4_indirect_buf = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, uncached=True)
self.eop_buffer = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.kernargs = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.ctx_save_restore_address = self._gpu_alloc(0x2C02000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.completion_signal = hsa.amd_signal_t.from_address(self.signals_page.va_addr)
self.completion_signal.value = 1
self.completion_signal.kind = hsa.AMD_SIGNAL_KIND_USER
self.completion_signal.event_mailbox_ptr = self.event_page.va_addr + self.sync_event.event_slot_index*8
self.completion_signal.event_mailbox_ptr = KFDDevice.event_page.va_addr + self.sync_event.event_slot_index*8
self.completion_signal.event_id = self.sync_event.event_id
# AQL Queue