refactor amd render_kernel (#6223)

* refactor amd render_kernel

* fix spacing

* add half alias back

* use itemsize * 8 instead of fixed values

* reverting because it broke as 32 was no longer the default

* remove comment

* remove nested tuples

* hotfix: prefix.append

* hotfix2: is not None

* more diff cleanups

* hotfix 4: spacing changes must not be in the same diff

* revert wmma dtype rendering

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
ignaciosica
2024-08-26 13:28:36 -03:00
committed by GitHub
parent 3132449086
commit 3918f6eea0

View File

@@ -357,11 +357,6 @@ def _make_hip_code_for_op():
return cast_bf16
return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
def _make_hip_dtype(base_type, name, cnt):
elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536
@@ -370,7 +365,7 @@ class AMDRenderer(CStyleLanguage):
# language options
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
for dt, n in [("float", 32), ("double", 64), ("_Float16", 16)]
for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
@@ -385,9 +380,13 @@ class AMDRenderer(CStyleLanguage):
uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
type_map = {dtypes.bfloat16: "hip_bfloat16"}
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
# TODO: add BF16 vec dts
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
@@ -407,11 +406,7 @@ static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat1
static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")
if any(uop.dtype == dtypes.half for uop in uops):
prefix.append("#define half _Float16")
vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
prefix += [_make_hip_dtype(*x) for x in vec_dts]
for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")