From 0e93b9642af716580c8f188335b2c05c6956d80f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 28 Jun 2023 19:21:01 +0000 Subject: [PATCH] hip matmul --- extra/gemm/hip_matmul.py | 62 +++++++++++++++++++++++++++++++++++++ extra/hip_wrapper.py | 15 +++++++++ tinygrad/runtime/ops_hip.py | 2 +- 3 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 extra/gemm/hip_matmul.py diff --git a/extra/gemm/hip_matmul.py b/extra/gemm/hip_matmul.py new file mode 100644 index 0000000000..bd40886396 --- /dev/null +++ b/extra/gemm/hip_matmul.py @@ -0,0 +1,62 @@ +import time +import numpy as np +from tinygrad.helpers import dtypes, getenv +from tinygrad.runtime.ops_hip import RawHIPBuffer, HIPProgram + +N = getenv("N", 1024) +assert N%16 == 0, "multiple of 16" +FLOPS = N*N*N*2 +BW = N*N*3*4 + +a = RawHIPBuffer(N*N, dtypes.float32) + +nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16) +nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16) +b = RawHIPBuffer.fromCPU(nb) +c = RawHIPBuffer.fromCPU(nc) + +prog = HIPProgram("test", f""" +typedef float float8 __attribute__((ext_vector_type(8))); +typedef _Float16 half16 __attribute__((ext_vector_type(16))); +extern "C" __global__ void test(float* c, __half* a, __half* b) {{ + const int gx = blockIdx.x; + const int gy = blockIdx.y; + + c += gx*16*{N} + gy*16; + a += gx*16*{N}; + b += gy*16; + + const int lIdx = threadIdx.x; + const int lane = lIdx % 16; + + half16 a_frag; + half16 b_frag; + float8 c_frag = {{}}; + + for (int k = 0; k < {N}; k += 16) {{ + for (int ele = 0; ele < 16; ++ele) {{ + a_frag[ele] = a[{N}*lane + (k+ele)]; + b_frag[ele] = b[(k+ele)*{N} + lane]; + }} + + c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag); + }} + + for (int ele = 0; ele < 8; ++ele) {{ + const int r = ele * 2 + (lIdx / 16); + c[{N}*r + lane] = c_frag[ele]; + }} +}}""") + +def timeit(fxn): + st = time.perf_counter() + et = fxn() + ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead + #print(f"{ret*1e6:.2f} us") + return ret + +tm = min([timeit(lambda: prog([N//16, N//16, 1], [32, 1, 1], a, b, c, wait=True)) for _ in range(20)]) +na = a.toCPU().reshape(N,N) +comp = nb.astype(np.float32) @ nc.astype(np.float32) +print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s") +np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2) \ No newline at end of file diff --git a/extra/hip_wrapper.py b/extra/hip_wrapper.py index b06dd77902..cc0ccd5230 100644 --- a/extra/hip_wrapper.py +++ b/extra/hip_wrapper.py @@ -561,8 +561,23 @@ def hiprtcCompileProgram(prog, options): c_options = (ctypes.c_char_p * len(e_options))() c_options[:] = e_options status = _libhiprtc.hiprtcCompileProgram(prog, len(c_options), c_options) + if status == 6: print(hiprtcGetProgramLog(prog)) hipCheckStatus(status) +_libhiprtc.hiprtcGetProgramLogSize.restype = int +_libhiprtc.hiprtcGetProgramLogSize.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t)] + +_libhiprtc.hiprtcGetProgramLog.restype = int +_libhiprtc.hiprtcGetProgramLog.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + +def hiprtcGetProgramLog(prog): + logsz = ctypes.c_size_t() + status = _libhiprtc.hiprtcGetProgramLogSize(prog, logsz) + hipCheckStatus(status) + logstr = ctypes.create_string_buffer(logsz.value) + status = _libhiprtc.hiprtcGetProgramLog(prog, logstr) + hipCheckStatus(status) + return logstr.value.decode() _libhiprtc.hiprtcGetCodeSize.restype = int _libhiprtc.hiprtcGetCodeSize.argtypes = [ctypes.c_void_p, # hiprtcProgram diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index b1cbcccfff..2ae227b4a7 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -55,7 +55,7 @@ class HIPCodegen(CStyleCodegen): lang = CStyleLanguage( kernel_prefix = "#define INFINITY (__builtin_inff())\nextern \"C\" __global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", half_prekernel = "", - gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)], + gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]) HIPBuffer = Compiled(RawHIPBuffer, HIPCodegen, HIPProgram, hip.hipDeviceSynchronize)