Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)
refactor amd render_kernel (#6223)
* refactor amd render_kernel
* fix spacing
* add half alias back
* use itemsize * 8 instead of fixed values
* reverting because it broke, as 32 was no longer the default
* remove comment
* remove nested tuples
* hotfix: prefix.append
* hotfix2: is not None
* more diff cleanups
* hotfix 4: spacing changes must not be in the same diff
* revert wmma dtype rendering

--------

Co-authored-by: qazal <qazal.software@gmail.com>
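The heart of the change: the module-level _make_hip_dtype helper and its hardcoded vec_dts table are gone, and AMDRenderer.render_vector_prefix now emits a typedef-plus-constructor pair for each vector dtype that actually appears in the uops. Below is a standalone sketch of the string it builds, not the actual AMDRenderer method; the _nms component-name string is an assumption standing in for the one in cstyle.py.

# Standalone sketch of the string render_vector_prefix builds -- same
# construction, no tinygrad imports. _nms is an assumed stand-in for the
# component-name string defined in cstyle.py.
_nms = "xyzwabcdefghijkl"

def render_vector_prefix_sketch(scal: str, vec: str, count: int) -> str:
  args  = ', '.join(f'{scal} {x}' for x in _nms[:count])   # "float x, float y, ..."
  elems = ', '.join(_nms[:count])                          # "x, y, ..."
  return (f"typedef {scal} {vec} __attribute__((ext_vector_type({count})));\n"
          f"static inline __attribute__((device)) {vec} make_{vec}({args}) "
          f"{{ return {{ {elems} }}; }}")

print(render_vector_prefix_sketch("float", "float4", 4))
# typedef float float4 __attribute__((ext_vector_type(4)));
# static inline __attribute__((device)) float4 make_float4(float x, float y, float z, float w) { return { x, y, z, w }; }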
@@ -357,11 +357,6 @@ def _make_hip_code_for_op():
     return cast_bf16
   return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
 
-def _make_hip_dtype(base_type, name, cnt):
-  elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-    f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
-
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
@@ -370,7 +365,7 @@ class AMDRenderer(CStyleLanguage):
   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
   ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
-    for dt, n in [("float", 32), ("double", 64), ("_Float16", 16)]
+    for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
     for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
 
   kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
@@ -385,9 +380,13 @@ class AMDRenderer(CStyleLanguage):
   uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
 
+  def render_vector_prefix(self, dtype:DType) -> str:
+    vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
+    return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
+      f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
+
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
-    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+    prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
 
     # TODO: add BF16 vec dts
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
@@ -407,11 +406,7 @@ static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
 static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
 
-    if any(uop.dtype == dtypes.half for uop in uops):
-      prefix.append("#define half _Float16")
-      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
-
-    prefix += [_make_hip_dtype(*x) for x in vec_dts]
+    for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
 
     for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
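Two details in the hunks above are easy to miss. The ocml bit widths are now derived rather than hardcoded: dtype.itemsize is in bytes, so itemsize * 8 yields the 32/64/16 suffixes the old tuples spelled out. And because "#define half _Float16" now sits unconditionally in the prefix (the "add half alias back" bullet), dtype.name can render as half where the old table used _Float16. A quick sketch of the derivation with stand-in dtypes; the namedtuple is illustrative, not tinygrad's actual DType.

from collections import namedtuple

# Illustrative stand-ins for dtypes.float, dtypes.double, dtypes.half;
# itemsize is in bytes, matching tinygrad's convention.
FakeDType = namedtuple("FakeDType", ["name", "itemsize"])

for dt in [FakeDType("float", 4), FakeDType("double", 8), FakeDType("half", 2)]:
  # reproduces the old hardcoded pairs ("float", 32), ("double", 64), ("_Float16", 16),
  # with "half" relying on the unconditional "#define half _Float16" in the prefix
  print((dt.name, dt.itemsize * 8))
# prints ('float', 32), ('double', 64), ('half', 16), one per line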