HIP backend fixes (#1336)

* feat: hip trains cifar

* feat: test_dtype fixes
This commit is contained in:
wozeparrot
2023-07-24 11:16:57 -04:00
committed by GitHub
parent 086382b64e
commit 51173f0a48

View File

@@ -14,18 +14,17 @@ if DEBUG >= 5:
# The default HIP stream is used for everything.
class RawHIPBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype):
self.buf_sz = size * dtype.itemsize
super().__init__(size, dtype, hip.hipMalloc(self.buf_sz))
def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.buf_sz, 0)
def _copyout(self, x:np.ndarray): hip.hipMemcpyAsync_dtoh(x.ctypes.data, self._buf, self.buf_sz, 0)
def __init__(self, size, dtype): super().__init__(size, dtype, hip.hipMalloc(size * dtype.itemsize))
def __del__(self): hip.hipFree(self._buf)
def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, 0)
def _copyout(self, x:np.ndarray): hip.hipMemcpy_dtoh(x.ctypes.data, self._buf, self.size * self.dtype.itemsize)
class HIPProgram:
def __init__(self, name:str, prg:str, binary=False):
try:
if not binary:
prog = hip.hiprtcCreateProgram(prg, name, [], [])
device_properties = hip.hipGetDeviceProperties(0)
device_properties = hip.hipGetDeviceProperties(hip.hipGetDevice())
hip.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])
prg = hip.hiprtcGetCode(prog)
except Exception as e:
@@ -53,8 +52,24 @@ class HIPProgram:
class HIPCodegen(CStyleCodegen):
lang = CStyleLanguage(
kernel_prefix = "#define INFINITY (__builtin_inff())\nextern \"C\" __global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
half_prekernel = "",
kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
extern "C" __global__
""",
smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", uses_vload=True,
half_prekernel = "#include <hip/hip_fp16.h>\nusing half4 = HIP_vector_type<half, 4>;" + """
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
""",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])