diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1cc7f7342e..f10a0e64cb 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -466,7 +466,7 @@ jobs:
       run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' --ignore=test/external --ignore=test/models --durations=20
     - name: Run pytest (hip)
       if: matrix.backend=='hip'
-      run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/imported/test_indexing.py --durations=20
+      run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/imported/test_indexing.py test/external/external_test_hip_compile.py --durations=20
 
   #testunicorn:
   #  name: ARM64 unicorn Test
diff --git a/test/external/external_test_hip_compile.py b/test/external/external_test_hip_compile.py
new file mode 100644
index 0000000000..dd8baf87fa
--- /dev/null
+++ b/test/external/external_test_hip_compile.py
@@ -0,0 +1,39 @@
+import time, unittest
+from tinygrad.runtime.driver.hip_comgr import compile_hip
+from tinygrad import Tensor
+from tinygrad.device import Device
+from tinygrad.realize import create_schedule
+from tinygrad.codegen.linearizer import Linearizer
+
+class TestHIPCompileSpeed(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT != "HIP", "only run on HIP")
+  def test_hip_compile(self):
+    a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
+    out = a + b
+    lin = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    lin.linearize()
+
+    reference = """
+#include <hip/hip_common.h>
+typedef long unsigned int size_t;
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+extern "C" __attribute__((global))void {name}(int* data0, const int* data1, const int* data2) {{
+  int gidx0 = __ockl_get_group_id(0); /* 5 */
+  int val0 = data1[gidx0];
+  int val1 = data2[gidx0];
+  data0[gidx0] = (val0+val1);
+}}
+"""
+
+    def time_compile(code):
+      st = time.perf_counter()
+      compile_hip(code)
+      return (time.perf_counter() - st) * 1000
+
+    tinygrad_tm = min([time_compile(Device[Device.DEFAULT].compiler.render(f"test{i}", lin.uops)) for i in range(10)])
+    ref_tm = min([time_compile(reference.format(name=f"test{i}")) for i in range(10)])
+    print(f"tinygrad {tinygrad_tm:6.2f} ms")
+    print(f"reference {ref_tm:6.2f} ms")
+    assert (tinygrad_tm - ref_tm) <= 10
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index c6d73c9315..56eeda23be 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -266,10 +266,8 @@ def _make_hip_dtype(base_type, name, cnt):
     ") { return {" + ', '.join(nms) + "}; }"
 
 class HIPLanguage(CStyleLanguage):
-  kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
-  typedef long unsigned int size_t;
+  kernel_prefix = """
   #define half _Float16
-  #include <hip/hip_bfloat16.h>
 
   extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
   extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
@@ -291,10 +289,9 @@ class HIPLanguage(CStyleLanguage):
   __attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
   __attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
   __attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
-  }\n""" + '\n'.join([
-    _make_hip_dtype(*x) for x in [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
-    ("float", "float", 8)
-  ]]) + """
+  }\n""" + '\n'.join([_make_hip_dtype(*x) for x in [
+    ("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
+    ("float", "float", 8)]]) + """
   static __attribute__((device)) half8 __hip_wmma_f16_f16(half16 a, half16 b, half8 c) {
     half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
     c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
@@ -309,4 +306,13 @@ class HIPLanguage(CStyleLanguage):
   float4 = "make_float4"
   uses_ptr_arithmetic = False  # NOTE: this fixes TestLinearizerOverflowAlt
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    prefix = ["#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
+              "typedef long unsigned int size_t;"]
+    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <hip/hip_bfloat16.h>")
+    else: prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
+                                                                ("signed int", "int", 4), ("signed int", "int", 2)]))
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
 HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())