From 3918f6eea042ceca033f3f87b5b2f48c2d63ffd2 Mon Sep 17 00:00:00 2001
From: ignaciosica
Date: Mon, 26 Aug 2024 13:28:36 -0300
Subject: [PATCH] refactor amd render_kernel (#6223)

* refactor amd render_kernel

* fix spacing

* add half alias back

* use itemsize

* 8 instead of fixed values

* reverting because it broke as 32 was no longer the default

* remove comment

* remove nested tuples

* hotfix: prefix.append

* hotfix2: is not None

* more diff cleanups

* hotfix 4: spacing changes must not be in the same diff

* revert wmma dtype rendering

---------

Co-authored-by: qazal
---
 tinygrad/renderer/cstyle.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 5db175b063..cd2b01a008 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -357,11 +357,6 @@ def _make_hip_code_for_op():
     return cast_bf16
   return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
 
-def _make_hip_dtype(base_type, name, cnt):
-  elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-         f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
-
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
@@ -370,7 +365,7 @@ class AMDRenderer(CStyleLanguage):
   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
   ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
-            for dt, n in [("float", 32), ("double", 64), ("_Float16", 16)]
+            for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
             for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
 
   kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
@@ -385,9 +380,13 @@ class AMDRenderer(CStyleLanguage):
   uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
 
+  def render_vector_prefix(self, dtype:DType) -> str:
+    vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
+    return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
+      f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
+
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
-    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+    prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
 
     # TODO: add BF16 vec dts
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
@@ -407,11 +406,7 @@ static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat1
 static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
 
-    if any(uop.dtype == dtypes.half for uop in uops):
-      prefix.append("#define half _Float16")
-      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
-
-    prefix += [_make_hip_dtype(*x) for x in vec_dts]
+    for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
 
     for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")