HIP backend (#750)
* llama works for HIP backend
* Use hipMemcpyAsync; fewer lines of code
* Remove unused code
* Refactor
* Add comments; hipDeviceSynchronize
* HIP over GPU; Remove PyHIP dependency
* Cleanups
* Fix mypy check
* Merge master; Dump assembly code
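Per the first bullet, the backend is complete enough to run llama end to end. A hypothetical way to try it, assuming tinygrad's usual convention of forcing a backend through an environment variable named after the device (the same variable would apply to the repo's examples, e.g. examples/llama.py):

import os
os.environ["HIP"] = "1"  # assumed device-selection convention; set before importing tinygrad
from tinygrad.tensor import Tensor
print((Tensor.ones(4, 4) @ Tensor.ones(4, 4)).numpy())  # small matmul through the HIP backend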
@@ -18,7 +18,7 @@ class TinyJit:
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
 
   def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
+    if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
     # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
     input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
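With "HIP" added to the whitelist, TinyJit traces and replays kernel launches on the new backend the same way it does for GPU/CLANG/METAL/CUDA. A minimal sketch of exercising it, assuming this era's API (TinyJit lives in tinygrad/jit.py, and jitted functions must return realized Tensors):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def mul2(x:Tensor) -> Tensor:
  return (x * 2).realize()  # output must be realized so the jit can capture the kernel

x = Tensor.randn(4, 4)
for _ in range(3):
  out = mul2(x)  # first call runs normally, second traces, later calls replay the capture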
@@ -319,7 +319,7 @@ class _Device:
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
   def _default_device(self) -> str:
-    for device in ["METAL", "CUDA", "GPU"]:
+    for device in ["METAL", "CUDA", "HIP", "GPU"]:
       try:
         if self[device]: return device
       except Exception: pass
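Placing "HIP" ahead of "GPU" means a machine with a working ROCm install now defaults to the new backend: _default_device returns the first entry whose ops_* runtime module imports and constructs successfully, swallowing failures for backends that aren't installed. A quick check, assuming Device is exported from tinygrad.ops at this commit:

from tinygrad.ops import Device
print(Device.DEFAULT)  # "HIP" on a ROCm box; otherwise the first of METAL/CUDA/GPU that probes OK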
tinygrad/runtime/ops_hip.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import numpy as np
+import ctypes
+import extra.hip_wrapper as hip
+from tinygrad.helpers import DEBUG
+from tinygrad.ops import Compiled
+from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
+
+# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
+if DEBUG >= 5:
+  from extra.helpers import enable_early_exec
+  early_exec = enable_early_exec()
+
+# The default HIP stream is used for everything.
+
+class RawHIPBuffer(RawBufferCopyInOut):
+  def __init__(self, size, dtype):
+    self.buf_sz = size * dtype.itemsize
+    super().__init__(size, dtype, hip.hipMalloc(self.buf_sz))
+  def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.buf_sz, 0)
+  def _copyout(self, x:np.ndarray): hip.hipMemcpyAsync_dtoh(x.ctypes.data, self._buf, self.buf_sz, 0)
+
+class HIPProgram:
+  def __init__(self, name:str, prg:str, binary=False):
+    try:
+      if not binary:
+        prog = hip.hiprtcCreateProgram(prg, name, [], [])
+        device_properties = hip.hipGetDeviceProperties(0)
+        hip.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])
+        prg = hip.hiprtcGetCode(prog)
+    except Exception as e:
+      if DEBUG >= 3: print("FAILED TO BUILD", prg)
+      raise e
+    if DEBUG >= 5:
+      asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
+      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
+
+    module = hip.hipModuleLoadData(prg)
+    self.prg = hip.hipModuleGetFunction(module, name)
+
+  def __call__(self, global_size, local_size, *args, wait=False):
+    local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
+    global_size = global_size + [1] * (3 - len(global_size))
+    assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
+    global_size = [x//y for x,y in zip(global_size, local_size)]
+    if wait:
+      start, end = hip.hipEventCreate(), hip.hipEventCreate()
+      hip.hipEventRecord(start)
+    class PackageStruct(ctypes.Structure):
+      _fields_ = [(f'field{idx}', ctypes.c_void_p) for idx in range(len(args))]
+    struct = PackageStruct(*[data._buf for data in args])
+    hip.hipModuleLaunchKernel(self.prg, global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
+    if wait:
+      hip.hipEventRecord(end)
+      hip.hipEventSynchronize(end)
+      return hip.hipEventElapsedTime(start, end)*1e-3
+
+class HIPCodegen(CStyleCodegen):
+  lang = CStyleLanguage(
+    kernel_prefix = "#define INFINITY (__builtin_inff())\nextern \"C\" __global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    half_prekernel = "",
+    gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
+    lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
+
+HIPBuffer = Compiled(RawHIPBuffer, HIPCodegen, HIPProgram, hip.hipDeviceSynchronize)
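The new file wires four pieces into a Compiled backend: RawHIPBuffer (allocation plus async host/device copies on stream 0), HIPProgram (hiprtc compile, module load, and a kernel launch with arguments packed into a ctypes struct), HIPCodegen (CUDA-style source with block/thread indexing), and hipDeviceSynchronize as the synchronize function. A hypothetical smoke test, not part of the commit: the kernel name double_it is invented, and dtypes is assumed to live in tinygrad.helpers at this point in history. Because the copies are async, a device synchronize is needed before reading the output:

import numpy as np
import extra.hip_wrapper as hip
from tinygrad.helpers import dtypes  # assumed location of dtypes at this commit
from tinygrad.runtime.ops_hip import RawHIPBuffer, HIPProgram

# hiprtc-compiled kernel; HIPProgram targets the gcnArchName of device 0
src = 'extern "C" __global__ void double_it(float *a) { a[threadIdx.x] *= 2.0f; }'
prg = HIPProgram("double_it", src)

buf = RawHIPBuffer(16, dtypes.float32)
buf._copyin(np.arange(16, dtype=np.float32))
t = prg([16, 1, 1], [16, 1, 1], buf, wait=True)  # grid divides to [1,1,1]; wait=True returns seconds
out = np.empty(16, dtype=np.float32)
buf._copyout(out)
hip.hipDeviceSynchronize()  # copies are async on the default stream, so sync before reading out
print(out, f"{t*1e6:.1f} us")

Running with DEBUG=5 additionally pipes the code object through /opt/rocm/llvm/bin/llvm-objdump and prints the GCN disassembly, per the early_exec path in HIPProgram.__init__.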