mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)

hip & cuda to gpuctypes (#2539)

* cuda with gpuctypes
* hip gpuctypes
* graphs
* rename + linter happy
* use cpu_time_execution
* no ji in build_kernel_node_params
* remove hip_wrapper
* hip fix
* no arc
* small changes
* no clean module in cudacpu
extra/dist/world.py (vendored, 21 changes)
@@ -1,23 +1,24 @@
 from typing import Optional
 import ctypes
 from extra import dist
 from multiprocessing import shared_memory
 from tinygrad.helpers import DEBUG, colored, getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
-try: from tinygrad.runtime.ops_hip import RawHIPBuffer
+try:
+  import gpuctypes.hip as hip
+  from tinygrad.runtime.ops_hip import RawHIPBuffer, check
 except: RawHIPBuffer = None
 from tinygrad.runtime.ops_disk import RawDiskBuffer
 from tinygrad.jit import CacheCollector
 from tinygrad.tensor import Tensor, Function
-import extra.hip_wrapper as hip
 import numpy as np

 # match the function signature of JITRunner so we can put it in the cache
 def __send_rb(args, variables=None, wait=False, jit=False):
   x, target_rank, y = args[:3]
   if RawHIPBuffer and x.__class__ is RawHIPBuffer:
-    hip.hipSetDevice(x._device)
-    hip.hipDeviceSynchronize()
+    check(hip.hipSetDevice(x._device))
+    check(hip.hipDeviceSynchronize())
   else:
     if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
     else: y.fromCPU(x.toCPU())
@@ -37,9 +38,9 @@ def __recv_rb(args, variables=None, wait=False, jit=False):
 def _send_rb(x:RawBuffer, target_rank:int):
   if RawHIPBuffer and x.__class__ is RawHIPBuffer:
     # send ipc handle
-    hip.hipSetDevice(x._device)
-    hip.hipDeviceSynchronize()
-    handle = hip.hipIpcGetMemHandle(x._buf)
+    check(hip.hipSetDevice(x._device))
+    check(hip.hipDeviceSynchronize())
+    check(hip.hipIpcGetMemHandle(ctypes.byref(handle := hip.hipIpcMemHandle_t()), x._buf))
     dist.OOB.send((handle, x._device), target_rank)

 # jit support
@@ -67,8 +68,8 @@ def _recv_rb(x:RawBuffer, target_rank:int):
   if RawHIPBuffer and isinstance(x, RawHIPBuffer):
     # open ipc handle
     handle, y_device = dist.OOB.recv(target_rank)
-    hip.hipSetDevice(y_device)
-    ptr = hip.hipIpcOpenMemHandle(handle, 0)
+    check(hip.hipSetDevice(y_device))
+    check(hip.hipIpcOpenMemHandle(ctypes.byref(ptr := ctypes.c_void_p()), handle, 0))

     # build a new buffer
     y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
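An aside on the pattern this diff introduces everywhere: status-returning driver calls get wrapped in check(), and out-parameters are allocated inline with a walrus expression inside ctypes.byref(). The sketch below is hypothetical (it is not part of the diff) and uses libc's clock_gettime as a stand-in for a HIP/CUDA entry point, since it has the same shape: it writes its result through a pointer argument and returns a status code.

import ctypes, ctypes.util

def check(status):
  if status != 0: raise RuntimeError(f"C call failed with status {status}")

class Timespec(ctypes.Structure):
  _fields_ = [("tv_sec", ctypes.c_long), ("tv_nsec", ctypes.c_long)]

libc = ctypes.CDLL(ctypes.util.find_library("c"))
# allocate the output inline, pass its address, and check the returned status
# (CLOCK_MONOTONIC is 1 on Linux; the constant differs on other platforms)
check(libc.clock_gettime(1, ctypes.byref(ts := Timespec())))
print(ts.tv_sec, ts.tv_nsec)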
extra/hip_wrapper.py (deleted; "remove hip_wrapper" in the commit message)
@@ -1,682 +0,0 @@
import ctypes
from tinygrad.helpers import DEBUG
import sys
import numpy as np
from typing import Any, Dict, List, Tuple
from dataclasses import dataclass

try:
  _libhip = ctypes.cdll.LoadLibrary("/opt/rocm/lib/libamdhip64.so")
  _libhiprtc = ctypes.cdll.LoadLibrary("/opt/rocm/lib/libhiprtc.so")

  _libhip.hipGetErrorString.restype = ctypes.c_char_p
  _libhip.hipGetErrorString.argtypes = [ctypes.c_int]
  def hipGetErrorString(status):
    return _libhip.hipGetErrorString(status).decode("utf-8")

  def hipCheckStatus(status):
    if status != 0: raise RuntimeError("HIP error %s: %s" % (status, hipGetErrorString(status)))

  _libhip.hipDeviceSynchronize.restype = int
  _libhip.hipDeviceSynchronize.argtypes = []
  def hipDeviceSynchronize():
    status = _libhip.hipDeviceSynchronize()
    hipCheckStatus(status)

  _libhip.hipStreamSynchronize.restype = int
  _libhip.hipStreamSynchronize.argtypes = [ctypes.c_void_p]
  def hipStreamSynchronize(stream):
    status = _libhip.hipStreamSynchronize(stream)
    hipCheckStatus(status)

  _libhip.hipEventCreate.restype = int
  _libhip.hipEventCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
  def hipEventCreate():
    ptr = ctypes.c_void_p()
    status = _libhip.hipEventCreate(ctypes.byref(ptr))
    hipCheckStatus(status)
    return ptr

  _libhip.hipEventRecord.restype = int
  _libhip.hipEventRecord.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  def hipEventRecord(event, stream=None):
    status = _libhip.hipEventRecord(event, stream)
    hipCheckStatus(status)

  _libhip.hipEventDestroy.restype = int
  _libhip.hipEventDestroy.argtypes = [ctypes.c_void_p]
  def hipEventDestroy(event):
    status = _libhip.hipEventDestroy(event)
    hipCheckStatus(status)

  _libhip.hipEventSynchronize.restype = int
  _libhip.hipEventSynchronize.argtypes = [ctypes.c_void_p]
  def hipEventSynchronize(event):
    status = _libhip.hipEventSynchronize(event)
    hipCheckStatus(status)

  _libhip.hipEventElapsedTime.restype = int
  _libhip.hipEventElapsedTime.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_void_p, ctypes.c_void_p]
  def hipEventElapsedTime(start, stop):
    t = ctypes.c_float()
    status = _libhip.hipEventElapsedTime(ctypes.byref(t), start, stop)
    hipCheckStatus(status)
    return t.value

  ## Stream Management

  # Stream capture modes:
  hipStreamCaptureModeGlobal = 0
  hipStreamCaptureModeThreadLocal = 1
  hipStreamCaptureModeRelaxed = 2

  _libhip.hipStreamCreate.restype = int
  _libhip.hipStreamCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
  def hipStreamCreate():
    ptr = ctypes.c_void_p()
    status = _libhip.hipStreamCreate(ctypes.byref(ptr))
    hipCheckStatus(status)
    return ptr

  _libhip.hipStreamDestroy.restype = int
  _libhip.hipStreamDestroy.argtypes = [ctypes.c_void_p]
  def hipStreamDestroy(stream):
    status = _libhip.hipStreamDestroy(stream)
    hipCheckStatus(status)

  _libhip.hipStreamBeginCapture.restype = int
  _libhip.hipStreamBeginCapture.argtypes = [ctypes.c_void_p, ctypes.c_int]
  def hipStreamBeginCapture(stream, mode=hipStreamCaptureModeGlobal):
    t = ctypes.c_float()
    status = _libhip.hipStreamBeginCapture(stream, mode)
    hipCheckStatus(status)

  _libhip.hipStreamEndCapture.restype = int
  _libhip.hipStreamEndCapture.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  def hipStreamEndCapture(stream):
    ptr = ctypes.c_void_p()
    status = _libhip.hipStreamEndCapture(stream, ctypes.byref(ptr))
    hipCheckStatus(status)
    return ptr

  _libhip.hipStreamGetCaptureInfo_v2.restype = int
  _libhip.hipStreamGetCaptureInfo_v2.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
  def hipStreamGetCaptureInfo_v2(stream):
    status_out = ctypes.c_void_p()
    id_out = ctypes.c_ulonglong()
    graph_out = ctypes.c_void_p()
    deps_out = ctypes.POINTER(ctypes.c_void_p)()
    num_deps = ctypes.c_size_t()
    status = _libhip.hipStreamGetCaptureInfo_v2(stream, ctypes.byref(status_out), ctypes.byref(id_out), ctypes.byref(graph_out), ctypes.byref(deps_out), ctypes.byref(num_deps))
    hipCheckStatus(status)
    deps = [ctypes.cast(deps_out[i], ctypes.c_void_p) for i in range(num_deps.value)]
    return status_out, id_out.value, graph_out, deps

  hipStreamAddCaptureDependencies = 0
  hipStreamSetCaptureDependencies = 1

  _libhip.hipStreamUpdateCaptureDependencies.restype = int
  _libhip.hipStreamUpdateCaptureDependencies.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
  def hipStreamUpdateCaptureDependencies(stream, deps, flags=hipStreamAddCaptureDependencies):
    deps_in = (ctypes.c_void_p * len(deps))()
    deps_in[:] = deps
    num_deps = ctypes.c_size_t()
    num_deps.value = len(deps)
    flags_in = ctypes.c_uint()
    flags_in.value = flags
    status = _libhip.hipStreamUpdateCaptureDependencies(stream, deps_in, num_deps, flags_in)
    hipCheckStatus(status)


  ## Graph Management

  _libhip.hipGraphCreate.restype = int
  _libhip.hipGraphCreate.argtypes = [ctypes.c_void_p, ctypes.c_uint]
  def hipGraphCreate():
    ptr = ctypes.c_void_p()
    status = _libhip.hipGraphCreate(ctypes.byref(ptr), 0)
    hipCheckStatus(status)
    return ptr

  _libhip.hipGraphInstantiate.restype = int
  _libhip.hipGraphInstantiate.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
  def hipGraphInstantiate(graph):
    ptr = ctypes.c_void_p()
    status = _libhip.hipGraphInstantiate(ctypes.byref(ptr), graph, 0, 0, 0)
    hipCheckStatus(status)
    return ptr

  _libhip.hipGraphDestroy.restype = int
  _libhip.hipGraphDestroy.argtypes = [ctypes.c_void_p]
  def hipGraphDestroy(graph):
    status = _libhip.hipGraphDestroy(graph)
    hipCheckStatus(status)

  _libhip.hipGraphExecDestroy.restype = int
  _libhip.hipGraphExecDestroy.argtypes = [ctypes.c_void_p]
  def hipGraphExecDestroy(gexec):
    status = _libhip.hipGraphExecDestroy(gexec)
    hipCheckStatus(status)

  _libhip.hipGraphLaunch.restype = int
  _libhip.hipGraphLaunch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  def hipGraphLaunch(graph_exec, stream=0):
    status = _libhip.hipGraphLaunch(graph_exec, stream)
    hipCheckStatus(status)

  class hipKernelNodeParams(ctypes.Structure):
    _fields_ = [("blockDimX", ctypes.c_uint32), ("blockDimY", ctypes.c_uint32), ("blockDimZ", ctypes.c_uint32),
                ("extra", ctypes.POINTER(ctypes.c_void_p)),
                ("func", ctypes.c_void_p),
                ("gridDimX", ctypes.c_uint32), ("gridDimY", ctypes.c_uint32), ("gridDimZ", ctypes.c_uint32),
                ("kernelParams", ctypes.POINTER(ctypes.c_void_p)),
                ("sharedMemBytes", ctypes.c_uint32)]

  @dataclass
  class kernelNodeParamsWrapper():
    c_struct: Any
    context: Any = None

  def getCStructForType(argtypes):
    fields = []
    for j,typ in enumerate(argtypes):
      fields.append((f'field{j}', typ))

    class CStructure(ctypes.Structure):
      _fields_ = fields
    return CStructure

  def setKernelNodeLaunchDims(npwrapper:kernelNodeParamsWrapper, grid, block):
    npwrapper.c_struct.blockDimX = block[0]
    npwrapper.c_struct.blockDimY = block[1]
    npwrapper.c_struct.blockDimZ = block[2]
    npwrapper.c_struct.gridDimX = grid[0]
    npwrapper.c_struct.gridDimY = grid[1]
    npwrapper.c_struct.gridDimZ = grid[2]

  def setKernelNodeParams(npwrapper:kernelNodeParamsWrapper, args, ids):
    for j,i in enumerate(ids):
      setattr(npwrapper.context[1], f'field{i}', args[j])

  def buildKernelNodeParams(args, argtypes, func, grid, block, sharedMemBytes=0):
    c_struct_t = getCStructForType(argtypes)
    struct = c_struct_t(*args)
    size = ctypes.c_size_t(ctypes.sizeof(struct))
    p_size = ctypes.c_void_p(ctypes.addressof(size))
    p_struct = ctypes.c_void_p(ctypes.addressof(struct))
    config = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), p_struct,
                                   ctypes.c_void_p(2), p_size, ctypes.c_void_p(3))
    params = hipKernelNodeParams(block[0], block[1], block[2], config, func, grid[0], grid[1], grid[2], None, sharedMemBytes)
    return kernelNodeParamsWrapper(c_struct=params, context=(size, struct, config))

  _libhip.hipGraphAddKernelNode.restype = int
  _libhip.hipGraphAddKernelNode.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p]
  def hipGraphAddKernelNode(graph, deps, params:kernelNodeParamsWrapper):
    graph_node = ctypes.c_void_p()
    deps_in = (ctypes.c_void_p * len(deps))()
    deps_in[:] = deps
    num_deps = ctypes.c_size_t(len(deps))
    status = _libhip.hipGraphAddKernelNode(ctypes.byref(graph_node), graph, deps_in, num_deps, ctypes.byref(params.c_struct))
    hipCheckStatus(status)
    return graph_node

  _libhip.hipGraphExecKernelNodeSetParams.restype = int
  _libhip.hipGraphExecKernelNodeSetParams.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
  def hipGraphExecKernelNodeSetParams(gexec, node, params:kernelNodeParamsWrapper):
    status = _libhip.hipGraphExecKernelNodeSetParams(gexec, node, ctypes.byref(params.c_struct))
    hipCheckStatus(status)

  _libhip.hipMalloc.restype = int
  _libhip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
  def hipMalloc(count):
    ptr = ctypes.c_void_p()
    status = _libhip.hipMalloc(ctypes.byref(ptr), count)
    hipCheckStatus(status)
    return ptr.value

  _libhip.hipFree.restype = int
  _libhip.hipFree.argtypes = [ctypes.c_void_p]
  def hipFree(ptr):
    status = _libhip.hipFree(ptr)
    hipCheckStatus(status)

  # memory copy modes
  hipMemcpyHostToHost = 0
  hipMemcpyHostToDevice = 1
  hipMemcpyDeviceToHost = 2
  hipMemcpyDeviceToDevice = 3
  hipMemcpyDefault = 4

  _libhip.hipHostMalloc.restype = int
  _libhip.hipHostMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, ctypes.c_uint32]
  def hipHostMalloc(count, flags=0):
    ptr = ctypes.c_void_p()
    status = _libhip.hipHostMalloc(ctypes.byref(ptr), count, flags)
    hipCheckStatus(status)
    return ptr.value

  _libhip.hipMemcpy.restype = int
  _libhip.hipMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
  def hipMemcpy(dst, src, count, direction):
    status = _libhip.hipMemcpy(dst, src, ctypes.c_size_t(count), direction)
    hipCheckStatus(status)

  _libhip.hipMemcpyAsync.restype = int
  _libhip.hipMemcpyAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_void_p]
  def hipMemcpyAsync(dst, src, count, direction, stream):
    status = _libhip.hipMemcpyAsync(dst, src, ctypes.c_size_t(count), direction, stream)
    hipCheckStatus(status)

  _libhip.hipDeviceEnablePeerAccess.restype = int
  _libhip.hipDeviceEnablePeerAccess.argtypes = [ctypes.c_int, ctypes.c_uint]
  def hipDeviceEnablePeerAccess(peerDevice, flags):
    status = _libhip.hipDeviceEnablePeerAccess(peerDevice, flags)
    hipCheckStatus(status)

  _libhip.hipMemGetInfo.restype = int
  _libhip.hipMemGetInfo.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  def hipMemGetInfo():
    free = ctypes.c_size_t()
    total = ctypes.c_size_t()
    status = _libhip.hipMemGetInfo(ctypes.byref(free), ctypes.byref(total))
    hipCheckStatus(status)
    return free.value, total.value

  class hipIpcMemHandle_t(ctypes.Structure):
    _fields_ = [("reserved", ctypes.c_char * 64)]

  _libhip.hipIpcGetMemHandle.restype = int
  _libhip.hipIpcGetMemHandle.argtypes = [ctypes.POINTER(hipIpcMemHandle_t), ctypes.c_void_p]
  def hipIpcGetMemHandle(ptr):
    handle = hipIpcMemHandle_t()
    status = _libhip.hipIpcGetMemHandle(ctypes.byref(handle), ptr)
    hipCheckStatus(status)
    return handle

  _libhip.hipIpcOpenMemHandle.restype = int
  _libhip.hipIpcOpenMemHandle.argtypes = [ctypes.POINTER(ctypes.c_void_p), hipIpcMemHandle_t, ctypes.c_uint]
  def hipIpcOpenMemHandle(handle, flags):
    ptr = ctypes.c_void_p()
    status = _libhip.hipIpcOpenMemHandle(ctypes.byref(ptr), handle, flags)
    hipCheckStatus(status)
    return ptr.value

  _libhip.hipIpcCloseMemHandle.restype = int
  _libhip.hipIpcCloseMemHandle.argtypes = [ctypes.c_void_p]
  def hipIpcCloseMemHandle(ptr):
    status = _libhip.hipIpcCloseMemHandle(ptr)
    hipCheckStatus(status)

  _libhip.hipSetDevice.restype = int
  _libhip.hipSetDevice.argtypes = [ctypes.c_int]
  def hipSetDevice(dev):
    status = _libhip.hipSetDevice(dev)
    hipCheckStatus(status)

  _libhip.hipGetDevice.restype = int
  _libhip.hipGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)]
  def hipGetDevice():
    dev = ctypes.c_int()
    status = _libhip.hipGetDevice(ctypes.byref(dev))
    hipCheckStatus(status)
    return dev.value

  _libhip.hipGetDeviceCount.restype = int
  _libhip.hipGetDeviceCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
  def hipGetDeviceCount():
    count = ctypes.c_int()
    status = _libhip.hipGetDeviceCount(ctypes.byref(count))
    hipCheckStatus(status)
    return count.value

  class hipDeviceArch(ctypes.Structure):
    _fields_ = [
      # *32-bit Atomics*
      # 32-bit integer atomics for global memory.
      ("hasGlobalInt32Atomics", ctypes.c_uint, 1),
      # 32-bit float atomic exch for global memory.
      ("hasGlobalFloatAtomicExch", ctypes.c_uint, 1),
      # 32-bit integer atomics for shared memory.
      ("hasSharedInt32Atomics", ctypes.c_uint, 1),
      # 32-bit float atomic exch for shared memory.
      ("hasSharedFloatAtomicExch", ctypes.c_uint, 1),
      # 32-bit float atomic add in global and shared memory.
      ("hasFloatAtomicAdd", ctypes.c_uint, 1),

      # *64-bit Atomics*
      # 64-bit integer atomics for global memory.
      ("hasGlobalInt64Atomics", ctypes.c_uint, 1),
      # 64-bit integer atomics for shared memory.
      ("hasSharedInt64Atomics", ctypes.c_uint, 1),

      # *Doubles*
      # Double-precision floating point.
      ("hasDoubles", ctypes.c_uint, 1),

      # *Warp cross-lane operations*
      # Warp vote instructions (__any, __all).
      ("hasWarpVote", ctypes.c_uint, 1),
      # Warp ballot instructions (__ballot).
      ("hasWarpBallot", ctypes.c_uint, 1),
      # Warp shuffle operations. (__shfl_*).
      ("hasWarpShuffle", ctypes.c_uint, 1),
      # Funnel two words into one with shift&mask caps.
      ("hasFunnelShift", ctypes.c_uint, 1),

      # *Sync*
      # __threadfence_system.
      ("hasThreadFenceSystem", ctypes.c_uint, 1),
      # __syncthreads_count, syncthreads_and, syncthreads_or.
      ("hasSyncThreadsExt", ctypes.c_uint, 1),

      # *Misc*
      # Surface functions.
      ("hasSurfaceFuncs", ctypes.c_uint, 1),
      # Grid and group dims are 3D (rather than 2D).
      ("has3dGrid", ctypes.c_uint, 1),
      # Dynamic parallelism.
      ("hasDynamicParallelism", ctypes.c_uint, 1),
    ]

  class hipDeviceProperties(ctypes.Structure):
    _fields_ = [
      # Device name
      ("_name", ctypes.c_char * 256),
      # Size of global memory region (in bytes)
      ("totalGlobalMem", ctypes.c_size_t),
      # Size of shared memory region (in bytes).
      ("sharedMemPerBlock", ctypes.c_size_t),
      # Registers per block.
      ("regsPerBlock", ctypes.c_int),
      # Warp size.
      ("warpSize", ctypes.c_int),
      # Max work items per work group or workgroup max size.
      ("maxThreadsPerBlock", ctypes.c_int),
      # Max number of threads in each dimension (XYZ) of a block.
      ("maxThreadsDim", ctypes.c_int * 3),
      # Max grid dimensions (XYZ).
      ("maxGridSize", ctypes.c_int * 3),
      # Max clock frequency of the multiProcessors in khz.
      ("clockRate", ctypes.c_int),
      # Max global memory clock frequency in khz.
      ("memoryClockRate", ctypes.c_int),
      # Global memory bus width in bits.
      ("memoryBusWidth", ctypes.c_int),
      # Size of shared memory region (in bytes).
      ("totalConstMem", ctypes.c_size_t),
      # Major compute capability. On HCC, this is an approximation and features may
      # differ from CUDA CC. See the arch feature flags for portable ways to query
      # feature caps.
      ("major", ctypes.c_int),
      # Minor compute capability. On HCC, this is an approximation and features may
      # differ from CUDA CC. See the arch feature flags for portable ways to query
      # feature caps.
      ("minor", ctypes.c_int),
      # Number of multi-processors (compute units).
      ("multiProcessorCount", ctypes.c_int),
      # L2 cache size.
      ("l2CacheSize", ctypes.c_int),
      # Maximum resident threads per multi-processor.
      ("maxThreadsPerMultiProcessor", ctypes.c_int),
      # Compute mode.
      ("computeMode", ctypes.c_int),
      # Frequency in khz of the timer used by the device-side "clock*"
      # instructions. New for HIP.
      ("clockInstructionRate", ctypes.c_int),
      # Architectural feature flags. New for HIP.
      ("arch", hipDeviceArch),
      # Device can possibly execute multiple kernels concurrently.
      ("concurrentKernels", ctypes.c_int),
      # PCI Domain ID
      ("pciDomainID", ctypes.c_int),
      # PCI Bus ID.
      ("pciBusID", ctypes.c_int),
      # PCI Device ID.
      ("pciDeviceID", ctypes.c_int),
      # Maximum Shared Memory Per Multiprocessor.
      ("maxSharedMemoryPerMultiProcessor", ctypes.c_size_t),
      # 1 if device is on a multi-GPU board, 0 if not.
      ("isMultiGpuBoard", ctypes.c_int),
      # Check whether HIP can map host memory
      ("canMapHostMemory", ctypes.c_int),
      # DEPRECATED: use gcnArchName instead
      ("gcnArch", ctypes.c_int),
      # AMD GCN Arch Name.
      ("_gcnArchName", ctypes.c_char * 256),
      # APU vs dGPU
      ("integrated", ctypes.c_int),
      # HIP device supports cooperative launch
      ("cooperativeLaunch", ctypes.c_int),
      # HIP device supports cooperative launch on multiple devices
      ("cooperativeMultiDeviceLaunch", ctypes.c_int),
      # Maximum size for 1D textures bound to linear memory
      ("maxTexture1DLinear", ctypes.c_int),
      # Maximum number of elements in 1D images
      ("maxTexture1D", ctypes.c_int),
      # Maximum dimensions (width, height) of 2D images, in image elements
      ("maxTexture2D", ctypes.c_int * 2),
      # Maximum dimensions (width, height, depth) of 3D images, in image elements
      ("maxTexture3D", ctypes.c_int * 3),
      # Address of HDP_MEM_COHERENCY_FLUSH_CNTL register
      ("hdpMemFlushCntl", ctypes.POINTER(ctypes.c_uint)),
      # Address of HDP_REG_COHERENCY_FLUSH_CNTL register
      ("hdpRegFlushCntl", ctypes.POINTER(ctypes.c_uint)),
      # Maximum pitch in bytes allowed by memory copies
      ("memPitch", ctypes.c_size_t),
      # Alignment requirement for textures
      ("textureAlignment", ctypes.c_size_t),
      # Pitch alignment requirement for texture references bound to pitched memory
      ("texturePitchAlignment", ctypes.c_size_t),
      # Run time limit for kernels executed on the device
      ("kernelExecTimeoutEnabled", ctypes.c_int),
      # Device has ECC support enabled
      ("ECCEnabled", ctypes.c_int),
      # 1: If device is Tesla device using TCC driver, else 0
      ("tccDriver", ctypes.c_int),
      # HIP device supports cooperative launch on multiple
      # devices with unmatched functions
      ("cooperativeMultiDeviceUnmatchedFunc", ctypes.c_int),
      # HIP device supports cooperative launch on multiple
      # devices with unmatched grid dimensions
      ("cooperativeMultiDeviceUnmatchedGridDim", ctypes.c_int),
      # HIP device supports cooperative launch on multiple
      # devices with unmatched block dimensions
      ("cooperativeMultiDeviceUnmatchedBlockDim", ctypes.c_int),
      # HIP device supports cooperative launch on multiple
      # devices with unmatched shared memories
      ("cooperativeMultiDeviceUnmatchedSharedMem", ctypes.c_int),
      # 1: if it is a large PCI bar device, else 0
      ("isLargeBar", ctypes.c_int),
      # Revision of the GPU in this device
      ("asicRevision", ctypes.c_int),
      # Device supports allocating managed memory on this system
      ("managedMemory", ctypes.c_int),
      # Host can directly access managed memory on the device without migration
      ("directManagedMemAccessFromHost", ctypes.c_int),
      # Device can coherently access managed memory concurrently with the CPU
      ("concurrentManagedAccess", ctypes.c_int),
      # Device supports coherently accessing pageable memory
      # without calling hipHostRegister on it
      ("pageableMemoryAccess", ctypes.c_int),
      # Device accesses pageable memory via the host's page tables
      ("pageableMemoryAccessUsesHostPageTables", ctypes.c_int),
    ]

    @property
    def name(self):
      return self._name.decode("utf-8")

    @property
    def gcnArchName(self):
      return self._gcnArchName.decode("utf-8")

  _libhip.hipGetDeviceProperties.restype = int
  _libhip.hipGetDeviceProperties.argtypes = [ctypes.POINTER(hipDeviceProperties), ctypes.c_int]
  def hipGetDeviceProperties(deviceId: int):
    device_properties = hipDeviceProperties()
    status = _libhip.hipGetDeviceProperties(ctypes.pointer(device_properties), deviceId)
    hipCheckStatus(status)
    return device_properties

  _libhip.hipModuleLoadData.restype = int
  _libhip.hipModuleLoadData.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p]
  def hipModuleLoadData(data):
    module = ctypes.c_void_p()
    status = _libhip.hipModuleLoadData(ctypes.byref(module), data)
    hipCheckStatus(status)
    return module

  _libhip.hipModuleGetFunction.restype = int
  _libhip.hipModuleGetFunction.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
  def hipModuleGetFunction(module, func_name):
    kernel = ctypes.c_void_p()
    status = _libhip.hipModuleGetFunction(ctypes.byref(kernel), module, func_name.encode("utf-8"))
    hipCheckStatus(status)
    return kernel

  _libhip.hipModuleUnload.restype = int
  _libhip.hipModuleUnload.argtypes = [ctypes.c_void_p]
  def hipModuleUnload(module):
    status = _libhip.hipModuleUnload(module)
    hipCheckStatus(status)

  _libhip.hipModuleLaunchKernel.restype = int
  _libhip.hipModuleLaunchKernel.argtypes = [ctypes.c_void_p,
                                            ctypes.c_uint, ctypes.c_uint, ctypes.c_uint,  # block dim
                                            ctypes.c_uint, ctypes.c_uint, ctypes.c_uint,  # thread dim
                                            ctypes.c_uint, ctypes.c_void_p,
                                            ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_void_p)]
  def hipModuleLaunchKernel(kernel, bx, by, bz, tx, ty, tz, shared, stream, struct):
    c_bx, c_by, c_bz = ctypes.c_uint(bx), ctypes.c_uint(by), ctypes.c_uint(bz)
    c_tx, c_ty, c_tz = ctypes.c_uint(tx), ctypes.c_uint(ty), ctypes.c_uint(tz)
    c_shared = ctypes.c_uint(shared)

    param_buffer_ptr, param_buffer_size, param_buffer_end = ctypes.c_void_p(1), ctypes.c_void_p(2), ctypes.c_void_p(3)
    size = ctypes.c_size_t(ctypes.sizeof(struct))
    p_size, p_struct = ctypes.c_void_p(ctypes.addressof(size)), ctypes.c_void_p(ctypes.addressof(struct))
    config = (ctypes.c_void_p * 5)(param_buffer_ptr, p_struct, param_buffer_size, p_size, param_buffer_end)

    status = _libhip.hipModuleLaunchKernel(kernel, c_bx, c_by, c_bz, c_tx, c_ty, c_tz, c_shared, stream, None, config)
    hipCheckStatus(status)

  _libhiprtc.hiprtcCreateProgram.restype = int
  _libhiprtc.hiprtcCreateProgram.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char),
                                             ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_char_p)]
  def hiprtcCreateProgram(source, name, header_names, header_sources):
    c_header_names, c_header_sources = (ctypes.c_char_p * len(header_names))(), (ctypes.c_char_p * len(header_sources))()
    c_header_names[:], c_header_sources[:] = [h.encode("utf-8") for h in header_names], [h.encode("utf-8") for h in header_sources]

    prog = ctypes.c_void_p()
    status = _libhiprtc.hiprtcCreateProgram(ctypes.byref(prog), source.encode("utf-8"), name.encode("utf-8"), len(header_names), c_header_sources, c_header_names)
    hipCheckStatus(status)
    return prog

  _libhiprtc.hiprtcDestroyProgram.restype = int
  _libhiprtc.hiprtcDestroyProgram.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
  def hiprtcDestroyProgram(prog):
    status = _libhiprtc.hiprtcDestroyProgram(ctypes.byref(prog))
    hipCheckStatus(status)


  _libhiprtc.hiprtcGetProgramLogSize.restype = int
  _libhiprtc.hiprtcGetProgramLogSize.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t)]
  _libhiprtc.hiprtcGetProgramLog.restype = int
  _libhiprtc.hiprtcGetProgramLog.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  def hiprtcGetProgramLog(prog):
    logsz = ctypes.c_size_t()
    status = _libhiprtc.hiprtcGetProgramLogSize(prog, logsz)
    hipCheckStatus(status)
    logstr = ctypes.create_string_buffer(logsz.value)
    status = _libhiprtc.hiprtcGetProgramLog(prog, logstr)
    hipCheckStatus(status)
    return logstr.value.decode()

  _libhiprtc.hiprtcCompileProgram.restype = int
  _libhiprtc.hiprtcCompileProgram.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)]
  def hiprtcCompileProgram(prog, options):
    c_options = (ctypes.c_char_p * len(options))()
    c_options[:] = [o.encode("utf-8") for o in options]

    status = _libhiprtc.hiprtcCompileProgram(prog, len(options), c_options)
    if status == 6: print(hiprtcGetProgramLog(prog))
    hipCheckStatus(status)

  _libhiprtc.hiprtcGetCodeSize.restype = int
  _libhiprtc.hiprtcGetCodeSize.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t)]
  _libhiprtc.hiprtcGetCode.restype = int
  _libhiprtc.hiprtcGetCode.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
  def hiprtcGetCode(prog):
    code_size = ctypes.c_size_t()
    status = _libhiprtc.hiprtcGetCodeSize(prog, ctypes.byref(code_size))
    hipCheckStatus(status)
    e_code = ("0" * code_size.value).encode("utf-8")
    status = _libhiprtc.hiprtcGetCode(prog, e_code)
    hipCheckStatus(status)
    return e_code
except:
  if DEBUG >= 1: print("WARNING: libamdhip64.so or libhiprtc.so not found. HIP support will not work.")
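The whole wrapper above collapses into two tiny helpers in the new code: check() raises on a nonzero status, and init_c_var() allocates and fills a ctypes variable in one expression. A runnable sketch of that replacement, where fake_hip_malloc is a hypothetical stand-in for hip.hipMalloc and init_c_var is copied verbatim from the tinygrad/helpers.py hunk later in this commit:

import ctypes

def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
def check(status):
  if status != 0: raise RuntimeError(f"HIP Error {status}")

@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t)
def fake_hip_malloc(ptr_out, size):
  ptr_out[0] = 0xdead0000  # pretend the driver returned this device address
  return 0                 # hipSuccess

buf = init_c_var(ctypes.c_void_p(), lambda x: check(fake_hip_malloc(ctypes.byref(x), 4096)))
print(hex(buf.value))  # 0xdead0000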
setup.py (5 changes)
@@ -19,15 +19,14 @@ setup(name='tinygrad',
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
       ],
-      install_requires=["numpy", "tqdm", "pyopencl",
+      install_requires=["numpy", "tqdm", "pyopencl", "gpuctypes",
                         "pyobjc-framework-Metal; platform_system=='Darwin'",
                         "pyobjc-framework-libdispatch; platform_system=='Darwin'"],
       python_requires='>=3.8',
       extras_require={
         'llvm': ["llvmlite"],
-        'cuda': ["pycuda"],
         'arm': ["unicorn"],
-        'triton': ["triton-nightly>=2.1.0.dev20231014192330", "pycuda"],
+        'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
         'webgpu': ["wgpu>=v0.12.0"],
         'linting': [
           "flake8",
tinygrad/helpers.py
@@ -57,6 +57,14 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
+def to_char_p_p(options: List[ctypes._CData], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(o, ctypes.POINTER(to_type)) for o in options])
+@functools.lru_cache(maxsize=None)
+def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+  class CStruct(ctypes.Structure):
+    _pack_, _fields_ = 1, fields
+  return CStruct
+def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
+def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1]

 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
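get_bytes is dense, so here it is exercised with hypothetical CFUNCTYPE fakes standing in for nvrtc/hiprtc getters. It shows the two-phase C idiom the helper wraps: ask for the size, then fill a caller-allocated buffer (init_c_var and get_bytes are copied verbatim from the hunk above).

import ctypes

def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1]

PAYLOAD = b"compiled code"

@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t))
def fake_get_size(arg, p_sz):
  p_sz[0] = len(PAYLOAD)  # phase 1: report how big the buffer must be
  return 0

@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
def fake_get_str(arg, buf):
  ctypes.memmove(buf, PAYLOAD, len(PAYLOAD))  # phase 2: fill the caller's buffer
  return 0

def check(status): assert status == 0
print(get_bytes(None, fake_get_size, fake_get_str, check))  # b'compiled code'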
@@ -257,7 +265,14 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=n
     pathlib.Path(f.name).rename(fp)
   return fp

-# *** pretty PTX printer
+# *** Exec helpers
+
+def cpu_time_execution(cb, enable):
+  if enable: st = time.perf_counter()
+  cb()
+  if enable: return time.perf_counter()-st
+
+# *** Helpers for CUDA-like APIs.

 def pretty_ptx(s):
   # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
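cpu_time_execution is the common timing shim that the Clang and LLVM backends (and CUDACPU) now share. A trivial usage sketch, with a toy callback: it returns elapsed seconds when enable=True, and just runs the callback (returning None) otherwise.

import time

def cpu_time_execution(cb, enable):
  if enable: st = time.perf_counter()
  cb()
  if enable: return time.perf_counter()-st

print(cpu_time_execution(lambda: sum(range(10**6)), enable=True))   # e.g. 0.02...
print(cpu_time_execution(lambda: sum(range(10**6)), enable=False))  # None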
@@ -268,3 +283,25 @@ def pretty_ptx(s):
   s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
   s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
   return s
+
+def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
+  check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
+  status = compile_prog(prog, len(compile_options), to_char_p_p([ctypes.create_string_buffer(o.encode()) for o in compile_options]))
+
+  if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}")
+  return get_bytes(prog, get_code_size, get_code, check)
+
+def encode_args_cuda_style(args, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
+  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t if not isinstance(x, int) else ctypes.c_int) for i,x in enumerate(args)]))(*args)
+  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args
+
+def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
+  if not enable: return cb()
+  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
+  evrecord(evs[0], None)
+  cb()
+  evrecord(evs[1], None)
+  evsync(evs[1])
+  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
+  for ev in evs: evdestroy(ev)
+  return ret.value * 1e-3
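The three `marks` values are the drivers' launch-config tags: CU_LAUNCH_PARAM_BUFFER_POINTER is 1 and CU_LAUNCH_PARAM_BUFFER_SIZE is 2 in both APIs, while the end marker is 0 for CUDA and 3 for HIP, which is why call sites pass marks=(1,2,0) or marks=(1,2,3). A self-contained rerun of encode_args_cuda_style (copied from the hunk above, with the lru_cache dropped and c_void_p standing in for the device-pointer type; None stands in for real device pointers):

import ctypes

def init_c_struct_t(fields):
  class CStruct(ctypes.Structure):
    _pack_, _fields_ = 1, fields
  return CStruct

def encode_args_cuda_style(args, device_ptr_t, marks):
  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t if not isinstance(x, int) else ctypes.c_int) for i,x in enumerate(args)]))(*args)
  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args

# two device-pointer slots (None placeholders) and one symbolic-variable int
config, c_args = encode_args_cuda_style([None, None, 42], ctypes.c_void_p, (1, 2, 0))
print(ctypes.sizeof(c_args))  # 20 on a 64-bit build: 8 + 8 + 4 with _pack_=1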
tinygrad/renderer/cuda.py
@@ -2,7 +2,7 @@ import functools
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

 class CUDALanguage(CStyleLanguage):
-  kernel_prefix = "__global__ "
+  kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ "
   smem_prefix = "__shared__ "
   smem_prefix_for_cast = False
   arg_int_prefix = "const int"
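Why the new kernel_prefix adds extern "C" (illustrative source strings below, not code from the diff): nvrtc compiles C++, so without it the kernel's symbol is mangled and the new look-up-by-name via cuModuleGetFunction in ops_cuda.py would fail; the old pycuda path instead parsed the name back out of the PTX, the "hacks" the removed TODO comment mentions.

src_old = '__global__ void add(float* a) { }'             # symbol becomes _Z3addPf (mangled)
src_new = 'extern "C" __global__ void add(float* a) { }'  # symbol stays add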
tinygrad/runtime/graph/cuda.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda
from tinygrad.helpers import init_c_var, encode_args_cuda_style
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException

class CUDAGraph:
  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException

    self.jit_cache = jit_cache
    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
    self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
    self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
    self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
    self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params)

    self.graph = self.graph_create()
    graph_node: Optional[ctypes._CData] = None

    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
    for j,ji in enumerate(self.jit_cache):
      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)

      c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
      c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs] + [var_vals[x] for x in prg.vars], *self.encode_args_info())
      c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
      graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)

      if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
        self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params)

    self.instance = self.graph_instantiate(self.graph)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    # Update rawbuffers in the c_input_params struct.
    for (j,i),input_idx in self.input_replace.items():
      setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)

    # Update var_vals in the c_input_params struct.
    for j in self.jc_idxs_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
        setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])

    # Update launch dims in the c_node_params struct.
    for j in self.jc_idxs_with_updatable_launch_dims:
      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))

    # Update graph nodes with the updated structs.
    for node, c_node_params, _ in self.updatable_nodes.values():
      self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params))

    et = self.graph_launch(self.instance, None, wait=wait)
    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
    return et

  def __del__(self):
    check(cuda.cuGraphDestroy(self.graph))
    check(cuda.cuGraphExecDestroy(self.instance))

  def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
  def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
  def graph_instantiate(self, graph): return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))
  def graph_add_kernel_node(self, graph, c_deps, c_node_params): return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params))))
  def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)
  def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args))
  def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)
  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
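How __call__ patches the saved structs between launches, in miniature: init_c_struct_t names fields f0, f1, ..., so the JIT can overwrite a buffer pointer or a variable value by name with setattr and then push the struct back into the instantiated graph. The values below are hypothetical; the Params class just mirrors the shape init_c_struct_t generates.

import ctypes

class Params(ctypes.Structure):
  _pack_, _fields_ = 1, [("f0", ctypes.c_void_p), ("f1", ctypes.c_void_p), ("f2", ctypes.c_int)]

p = Params(0x1000, 0x2000, 42)
setattr(p, "f1", 0x3000)  # swap in a new input buffer address, as __call__ does
p.f2 = 7                  # update a symbolic-variable value
print(hex(p.f1), p.f2)    # 0x3000 7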
tinygrad/runtime/graph/hip.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import ctypes
from typing import Tuple
import gpuctypes.hip as hip
from tinygrad.helpers import init_c_var
from tinygrad.runtime.ops_hip import check, hip_time_execution
from tinygrad.runtime.graph.cuda import CUDAGraph

class HIPGraph(CUDAGraph):
  def __del__(self):
    check(hip.hipGraphDestroy(self.graph))
    check(hip.hipGraphExecDestroy(self.instance))

  def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
  def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
  def graph_instantiate(self, graph): return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
  def graph_add_kernel_node(self, graph, c_deps, c_params): return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params))))
  def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
  def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
  def build_kernel_node_params(self, prg, global_size, local_size, c_config): return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
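One detail both graph classes share: graph_add_kernel_node derives numDependencies as ctypes.sizeof(c_deps)//8, which assumes 8-byte handles, i.e. a 64-bit build. A quick check of that arithmetic with a stand-in handle type:

import ctypes

node_t = ctypes.c_void_p           # stand-in for hip.hipGraphNode_t / cuda.CUgraphNode
c_deps = (node_t * 2)()            # an array holding two dependency handles
print(ctypes.sizeof(c_deps) // 8)  # 2, the numDependencies the C API expects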
tinygrad/runtime/ops_clang.py
@@ -1,7 +1,7 @@
-import time, ctypes, subprocess, platform, functools, pathlib, tempfile
+import ctypes, subprocess, platform, functools, pathlib, tempfile
 from typing import Any
 from tinygrad.device import Compiled, MallocAllocator
-from tinygrad.helpers import diskcache
+from tinygrad.helpers import diskcache, cpu_time_execution
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

@@ -26,10 +26,7 @@ class ClangProgram:
     pathlib.Path(cached_file_path.name).write_bytes(prg)
     self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]

-  def __call__(self, *args, wait=False):
-    if wait: st = time.perf_counter()
-    self.fxn(*args)
-    if wait: return time.perf_counter()-st
+  def __call__(self, *args, wait=False): return cpu_time_execution(lambda: self.fxn(*args), enable=wait)

 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
 ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
tinygrad/runtime/ops_cuda.py
@@ -1,80 +1,71 @@
-import subprocess, time, hashlib, tempfile
+import subprocess, hashlib, tempfile, ctypes, ctypes.util
 from pathlib import Path
-from typing import Tuple
-import numpy as np
-from pycuda.compiler import compile as cuda_compile
-from tinygrad.helpers import DEBUG, getenv, pretty_ptx, diskcache
+from typing import Tuple, Optional
+import gpuctypes.cuda as cuda
+from tinygrad.helpers import DEBUG, getenv, diskcache, from_mv, init_c_var, pretty_ptx, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
 from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cuda import CUDARenderer

-def arch(): return "sm_" + "".join([str(x) for x in pycuda.driver.Context.get_device().compute_capability()])
+CUDA_INCLUDE_PATH = getenv("CUDA_INCLUDE_PATH", default="-I/usr/local/cuda/include")
 CUDACPU = getenv("CUDACPU") == 1

 if CUDACPU:
-  import ctypes, ctypes.util
-  lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
-  lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
-  class cuda:
-    class module:
-      def __init__(self, src): self.src = src
-      def get_function(self, _): return self
-      def __call__(self, *args, block, grid, shared): lib.ptx_run(self.src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), *block, *grid, shared)
-    module_from_buffer = lambda src: cuda.module(src) # pylint: disable=unnecessary-lambda # noqa: E731
-    class Event:
-      def __init__(self): pass
-      def record(self): self.start = time.perf_counter()
-      def time_till(self, other): return other.start - self.start
-      def synchronize(self): pass
-    class Context:
-      synchronize = lambda:0 # noqa: E731
-    CompileError = Exception
-  class context:
-    class device:
-      compute_capability = lambda: (3,5) # pylint: disable=unnecessary-lambda # noqa: E731
-    get_device = lambda: context.device # pylint: disable=unnecessary-lambda # noqa: E731
-  import pycuda.driver
-  pycuda.driver.Context = context
-else:
-  import pycuda.autoprimaryctx # pylint: disable=unused-import # noqa: F401
-  import pycuda.driver as cuda # type: ignore
-class CUDAAllocator(LRUAllocator):
-  def _alloc(self, size, dtype):
-    if size == 0: return None
-    return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
-  def copyin(self, dest, src:memoryview): cuda.memcpy_htod_async(dest, src) # type: ignore
-  def copyout(self, dest:memoryview, src): cuda.memcpy_dtoh(dest, src) # type: ignore
+  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
+  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
+  cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared)

+def check(status):
+  if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}")

+def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable)

 @diskcache
-def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
+def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', CUDA_INCLUDE_PATH], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)

 class CUDAProgram:
-  def __init__(self, name:str, _prg:bytes, bufs:int, vars:int=0):
-    prg = _prg.decode('utf-8')
-    if DEBUG >= 5: print(pretty_ptx(prg))
+  def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0):
+    if DEBUG >= 5: print(pretty_ptx(prg.decode('utf-8')))
     if DEBUG >= 6:
       try:
-        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}").as_posix()
-        with open(fn + ".ptx", "wb") as f: f.write(prg.encode('utf-8'))
-        subprocess.run(["ptxas", f"-arch={arch()}", "-o", fn, fn+".ptx"], check=True)
+        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(prg).hexdigest()}").as_posix()
+        with open(fn + ".ptx", "wb") as f: f.write(prg)
+        subprocess.run(["ptxas", f"-arch={CUDADevice.default_arch_name}", "-o", fn, fn+".ptx"], check=True)
         print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
       except Exception as e: print("failed to generate SASS", str(e))
-    # TODO: name is wrong, so we get it from the ptx using hacks
-    self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])

+    if not CUDACPU:
+      self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), prg)))
+      check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
+    self.prg = prg

-  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], shared:int=0, wait=False):
-    if wait:
-      start, end = cuda.Event(), cuda.Event()
-      start.record()
-    self.prg(*[np.int32(x) if (isinstance(x, int) and not CUDACPU) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=shared)
-    if wait:
-      end.record()
-      end.synchronize()
-      return start.time_till(end)*1e-3
+  def __del__(self):
+    if not CUDACPU: check(cuda.cuModuleUnload(self.module))

+  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
+    c_kernel_input_config = encode_args_cuda_style(args, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else args
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)

+class CUDAAllocator(LRUAllocator):
+  def _alloc(self, size, dtype):
+    if size == 0: return None
+    return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size * dtype.itemsize)))
+  def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
+  def copyin(self, dest, src:memoryview): check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
+  def copyout(self, dest:memoryview, src): check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))

 class CUDADevice(Compiled):
+  default_arch_name = "sm_35"
   def __init__(self, device:str):
-    super().__init__(MallocAllocator if CUDACPU else CUDAAllocator(),
+    self.device = int(device.split(":")[1]) if ":" in device else 0
+    if not CUDACPU:
+      check(cuda.cuInit(0))
+      check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), self.device))
+      check(cuda.cuCtxCreate_v2(ctypes.byref(_ := cuda.CUcontext()), 0, device))
+      check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), self.device))
+      if self.device == 0: CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}"

+    from tinygrad.runtime.graph.cuda import CUDAGraph
+    super().__init__(CUDAAllocator() if not CUDACPU else MallocAllocator,
                      LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
-                     CUDARenderer, compile_cuda, CUDAProgram)
-  def synchronize(self): return cuda.Context.synchronize()
+                     CUDARenderer, compile_cuda, CUDAProgram, graph=CUDAGraph if not CUDACPU else None)
+  def synchronize(self): return check(cuda.cuCtxSynchronize()) if not CUDACPU else None
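The CUDACPU path above is worth a second look: instead of a whole fake pycuda module, the new code monkeypatches a single entry point, cuda.cuLaunchKernel, with a lambda that reorders the launch arguments into a call to the gpuocelot CPU emulator. A toy version with hypothetical names (fake_cuda and emulator_ptx_run stand in for the gpuctypes module and gpuocelot):

class fake_cuda: pass  # stands in for the imported gpuctypes.cuda module

def emulator_ptx_run(src, nargs, args, lx, ly, lz, gx, gy, gz, shared):
  print(f"emulating {nargs}-arg kernel, grid=({gx},{gy},{gz}) block=({lx},{ly},{lz})")
  return 0

fake_cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, extra, args: \
  emulator_ptx_run(src, len(args), args, lx, ly, lz, gx, gy, gz, shared)
fake_cuda.cuLaunchKernel(b"ptx", 1, 1, 1, 32, 1, 1, 0, None, None, ())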
tinygrad/runtime/ops_hip.py
@@ -1,7 +1,7 @@
 import ctypes, functools
-import extra.hip_wrapper as hip
-from typing import Tuple, cast, Callable, TypeVar
-from tinygrad.helpers import DEBUG, DType, getenv, diskcache, from_mv
+from typing import Tuple, TypeVar
+import gpuctypes.hip as hip
+from tinygrad.helpers import DEBUG, DType, getenv, diskcache, from_mv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
 from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
 from tinygrad.renderer.hip import HIPRenderer
 from tinygrad.codegen.kernel import LinearizerOptions
@@ -14,48 +14,34 @@ if DEBUG >= 6:
 # The default HIP stream is used for everything.
 MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile

-@diskcache
-def compile_hip(prg) -> bytes:
-  prog = hip.hiprtcCreateProgram(prg, "<null>", [], [])
-  hip.hiprtcCompileProgram(prog, [f'--offload-arch={HIPDevice.default_arch_name}'])
-  return hip.hiprtcGetCode(prog)
+def check(status):
+  if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")

-def time_execution(cb, enable=False):
-  if enable:
-    start, end = hip.hipEventCreate(), hip.hipEventCreate()
-    hip.hipEventRecord(start)
-  cb()
-  if enable:
-    hip.hipEventRecord(end)
-    hip.hipEventSynchronize(end)
-    ret = hip.hipEventElapsedTime(start, end)*1e-3
-    hip.hipEventDestroy(start)
-    hip.hipEventDestroy(end)
-    return ret
+def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable)

+@diskcache
+def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check)

 class HIPProgram:
   def __init__(self, device:int, name:str, prg:bytes, bufs:int, vars:int=0):
-    self.device, self.c_struct_t = device, None
+    self.device = device

     if DEBUG >= 6:
       asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
       print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

     if MOCKHIP: return
-    hip.hipSetDevice(self.device)
-    self.module = hip.hipModuleLoadData(prg)
-    self.prg = hip.hipModuleGetFunction(self.module, name)
-    self.c_struct_t = hip.getCStructForType([ctypes.c_void_p]*bufs + [ctypes.c_int]*vars)
-
-  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
-    if MOCKHIP: return
-    hip.hipSetDevice(self.device)
-    c_params = cast(Callable, self.c_struct_t)(*args)
-    return time_execution(lambda: hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, 0, c_params), enable=wait)
+    check(hip.hipSetDevice(self.device))
+    self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), prg)))
+    self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))

   def __del__(self):
-    if MOCKHIP: return
-    hip.hipModuleUnload(self.module)
+    if not MOCKHIP: check(hip.hipModuleUnload(self.module))

+  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
+    if MOCKHIP: return float("inf")
+    check(hip.hipSetDevice(self.device))
+    return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait)

 T = TypeVar("T")
 class HIPAllocator(LRUAllocator):
@@ -64,23 +50,25 @@ class HIPAllocator(LRUAllocator):
     super().__init__()
   def _alloc(self, size: int, dtype: DType):
     if size == 0: return None
-    hip.hipSetDevice(self.device)
-    return hip.hipMalloc(size * dtype.itemsize)
-  def _free(self, opaque:T): hip.hipFree(opaque)
+    check(hip.hipSetDevice(self.device))
+    return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size * dtype.itemsize)))
+  def _free(self, opaque:T): check(hip.hipFree(opaque))
   def copyin(self, dest:T, src: memoryview):
-    hip.hipSetDevice(self.device)
-    hip.hipMemcpyAsync(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice, 0)
+    check(hip.hipSetDevice(self.device))
+    check(hip.hipMemcpyAsync(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice, None))
   def copyout(self, dest:memoryview, src:T):
-    hip.hipSetDevice(self.device)
-    hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost)
+    check(hip.hipSetDevice(self.device))
+    check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
  def transfer(self, dest:T, src:T, sz:int):
-    hip.hipSetDevice(self.device)
-    hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice)
+    check(hip.hipSetDevice(self.device))
+    check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))

 class HIPDevice(Compiled):
   default_arch_name = "gfx1100"
   def __init__(self, device:str):
     self.device = int(device.split(":")[1]) if ":" in device else 0
-    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = hip.hipGetDeviceProperties(self.device).gcnArchName
-    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device))
+    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()

+    from tinygrad.runtime.graph.hip import HIPGraph
+    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
   def synchronize(self): hip.hipDeviceSynchronize()
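hip_time_execution above delegates to time_execution_cuda_style, which measures with driver events rather than host clocks. A pure-Python sketch of that control flow, using a hypothetical FakeEvent as a CPU stand-in for the evcreate/evrecord/evsync/evtime hooks: record one event before and one after the callback, then ask for the elapsed time between them.

import time

class FakeEvent:  # stands in for a hipEvent_t/CUevent handle
  def record(self): self.t = time.perf_counter()

def time_execution_sketch(cb, enable=False):
  if not enable: return cb()
  evs = [FakeEvent(), FakeEvent()]
  evs[0].record()
  cb()
  evs[1].record()
  return evs[1].t - evs[0].t  # the real helper converts the driver's ms to s with *1e-3

print(time_execution_sketch(lambda: sum(range(10**5)), enable=True))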
tinygrad/runtime/ops_llvm.py
@@ -1,7 +1,7 @@
-import time, ctypes
+import ctypes
 from typing import ClassVar
 from tinygrad.device import Compiled, MallocAllocator
-from tinygrad.helpers import getenv, DEBUG, diskcache
+from tinygrad.helpers import getenv, DEBUG, diskcache, cpu_time_execution
 from ctypes import CFUNCTYPE
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.llvmir import uops_to_llvm_ir
@@ -59,9 +59,6 @@ class LLVMProgram:
     self.fxn = LLVM.engine.get_function_address(name)
     self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*bufs), *([ctypes.c_int]*vars))(self.fxn)

-  def __call__(self, *bufs, wait=False):
-    if wait: st = time.perf_counter()
-    self.cfunc(*bufs)
-    if wait: return time.perf_counter()-st
+  def __call__(self, *bufs, wait=False): return cpu_time_execution(lambda: self.cfunc(*bufs), enable=wait)

 LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)