mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 17:15:48 -05:00
hip launch speed (#3246)
* faster HIP kernel launch * args * expand compile_hip
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
import ctypes
|
||||
from typing import Tuple
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.helpers import init_c_var
|
||||
from tinygrad.runtime.ops_hip import check, hip_time_execution
|
||||
from tinygrad.helpers import init_c_var, time_execution_cuda_style
|
||||
from tinygrad.runtime.ops_hip import check
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
|
||||
# TODO: this is only used in graph
|
||||
def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501
|
||||
|
||||
class HIPGraph(CUDAGraph):
|
||||
def __del__(self):
|
||||
if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
import ctypes, functools, subprocess, io
|
||||
from typing import Tuple, TypeVar, List, Any, cast, Set
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored
|
||||
from tinygrad.helpers import DEBUG, getenv, init_c_var
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t, to_char_p_p, get_bytes
|
||||
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, BufferOptions, JITRunner, Device, Buffer, update_stats
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -11,13 +11,22 @@ from tinygrad.codegen.kernel import LinearizerOptions
|
||||
# The default HIP stream is used for everything.
|
||||
MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile
|
||||
|
||||
hip_current_device = None
|
||||
def hip_set_device(d:int):
|
||||
global hip_current_device
|
||||
if d == hip_current_device: return
|
||||
check(hip.hipSetDevice(d))
|
||||
hip_current_device = d
|
||||
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
|
||||
|
||||
# TODO: remove these helpers, they increase complexity
|
||||
def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501
|
||||
|
||||
def compile_hip(prg:str, arch="gfx1100") -> bytes: return compile_cuda_style(prg, [f'--offload-arch={arch}', '-I/opt/rocm/include'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) # noqa: E501
|
||||
def compile_hip(prg:str, arch="gfx1100") -> bytes:
|
||||
check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "<null>".encode(), 0, None, None))
|
||||
compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include']
|
||||
status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
|
||||
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}")
|
||||
return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)
|
||||
|
||||
class HIPProgram:
|
||||
def __init__(self, device:int, name:str, lib:bytes):
|
||||
@@ -28,7 +37,7 @@ class HIPProgram:
|
||||
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
|
||||
|
||||
if MOCKHIP: return
|
||||
check(hip.hipSetDevice(self.device))
|
||||
hip_set_device(self.device)
|
||||
self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
|
||||
self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))
|
||||
|
||||
@@ -37,8 +46,27 @@ class HIPProgram:
|
||||
|
||||
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if MOCKHIP: return float("inf")
|
||||
check(hip.hipSetDevice(self.device))
|
||||
return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) # noqa: E501
|
||||
hip_set_device(self.device)
|
||||
if not hasattr(self, "vargs"):
|
||||
self.c_args = init_c_struct_t(tuple([(f'f{i}', hip.hipDeviceptr_t) for i in range(len(args))] +
|
||||
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
|
||||
self.vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(self.c_args), ctypes.c_void_p),
|
||||
ctypes.c_void_p(2), ctypes.cast(ctypes.byref(ctypes.c_size_t(ctypes.sizeof(self.c_args))), ctypes.c_void_p),
|
||||
ctypes.c_void_p(3))
|
||||
else:
|
||||
for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
|
||||
for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
|
||||
if wait:
|
||||
evs = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
|
||||
check(hip.hipEventRecord(evs[0], None))
|
||||
check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs))
|
||||
if wait:
|
||||
check(hip.hipEventRecord(evs[1], None))
|
||||
check(hip.hipEventSynchronize(evs[1]))
|
||||
check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1]))
|
||||
for ev in evs: check(hip.hipEventDestroy(ev))
|
||||
return ret.value * 1e-3
|
||||
return None
|
||||
|
||||
T = TypeVar("T")
|
||||
CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
|
||||
@@ -52,10 +80,10 @@ class HIPAllocator(LRUAllocator):
|
||||
for x in self.track_cross_device: x.synchronize()
|
||||
return super().free_cache()
|
||||
def _alloc(self, size:int):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
|
||||
def _alloc_with_options(self, size:int, options:BufferOptions):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
if options.uncached:
|
||||
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
|
||||
elif options.host:
|
||||
@@ -64,7 +92,7 @@ class HIPAllocator(LRUAllocator):
|
||||
raise Exception("no options")
|
||||
def _free(self, opaque:T): check(hip.hipFree(opaque))
|
||||
def copy_from_fd(self, dest, fd, offset, size):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
if not hasattr(self, 'hb'):
|
||||
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
|
||||
self.hb_events = [None, None]
|
||||
@@ -88,17 +116,17 @@ class HIPAllocator(LRUAllocator):
|
||||
self.hb_polarity = (self.hb_polarity+1) % len(self.hb)
|
||||
minor_offset = 0 # only on the first
|
||||
def copyin(self, dest:T, src: memoryview):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
|
||||
self.device.pending_copyin.append(host_mem)
|
||||
ctypes.memmove(host_mem, from_mv(src), len(src))
|
||||
check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
|
||||
def copyout(self, dest:memoryview, src:T):
|
||||
self.device.synchronize()
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
|
||||
def transfer(self, dest:T, src:T, sz:int):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
|
||||
|
||||
class HIPDevice(Compiled):
|
||||
@@ -113,14 +141,14 @@ class HIPDevice(Compiled):
|
||||
super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
|
||||
functools.partial(compile_hip,arch=self.arch), f"compile_hip_{self.arch}", functools.partial(HIPProgram, self.device), HIPGraph)
|
||||
def synchronize(self):
|
||||
check(hip.hipSetDevice(self.device))
|
||||
hip_set_device(self.device)
|
||||
check(hip.hipDeviceSynchronize())
|
||||
for opaque in self.pending_copyin: check(hip.hipFree(opaque))
|
||||
self.track_cross_buffer.clear()
|
||||
self.pending_copyin.clear()
|
||||
def enable_peer(self, dnum):
|
||||
if self.device == dnum or dnum in self.peers: return
|
||||
check(hip.hipSetDevice(self.device))
|
||||
hip_set_device(self.device)
|
||||
check(hip.hipDeviceEnablePeerAccess(dnum, 0))
|
||||
self.peers.add(dnum)
|
||||
|
||||
@@ -130,7 +158,7 @@ class HIPSyncEvent(JITRunner):
|
||||
super().__init__()
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
|
||||
to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
|
||||
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)
|
||||
|
||||
@@ -139,6 +167,6 @@ class HIPWaitEvent(JITRunner):
|
||||
self.device, self.dname = cast(HIPDevice, Device[device]), device
|
||||
super().__init__()
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
|
||||
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, device=self.dname)
|
||||
|
||||
Reference in New Issue
Block a user