tiny version of amd_hip_bfloat16 (#3868)

* add src_dtype

* add maker

* add bfloat16

* simpler
This commit is contained in:
qazal
2024-03-22 17:37:30 +02:00
committed by GitHub
parent 82ce60e172
commit 4a27ce6ec9

View File

@@ -259,6 +259,16 @@ code_for_op_hip = {
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
}
def _make_hip_code_for_op():
  """Build the HIP op-rendering table with bfloat16 handled via float.

  The table is the generic C-style op table overlaid with the HIP-specific
  entries; every renderer in it is wrapped so that a call whose trailing
  dtype argument is bfloat16 is rendered as a float computation and the
  result is cast back to ``hip_bfloat16``.
  """
  def wrapper(key, func):
    # WHERE keeps its first argument (the select condition) as-is; only the
    # remaining value operands are widened to float.
    is_where = key is TernaryOps.WHERE

    def cast_bf16(*args):
      # The dtype is always the last positional argument of a renderer.
      if not args[-1] == dtypes.bfloat16:
        return func(*args)
      if is_where:
        lead, raw = (args[0],), args[1:-1]
      else:
        lead, raw = (), args[:-1]
      widened = tuple(f"(float)({arg})" for arg in raw)
      inner = func(*(lead + widened), dtypes.float)
      return f"(hip_bfloat16)({inner})"

    return cast_bf16

  merged = {**CStyleLanguage().code_for_op, **code_for_op_hip}
  return {op: wrapper(op, render) for op, render in merged.items()}
def _make_hip_dtype(base_type, name, cnt):
nms = "xyzwabcdefghijkl"[:cnt]
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
@@ -299,7 +309,7 @@ class HIPLanguage(CStyleLanguage):
}\nextern "C" __attribute__((global))"""
# Index expressions per work-item kind: "g" -> group id, "l" -> local id,
# "i" -> group_id * local_size + local_id.
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
  "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
# Single binding of the op table: the plain merged-dict assignment that used to
# precede this line was overwritten immediately and has been dropped.
code_for_op = _make_hip_code_for_op()
smem_prefix = "__attribute__((shared))"
# Release fence, then s_barrier, then acquire fence, all at "workgroup" scope.
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
  '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
@@ -310,8 +320,31 @@ class HIPLanguage(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  """Render a kernel with the HIP preamble prepended.

  The preamble always contains the ``hip_common.h`` include, ``INFINITY`` /
  ``NAN`` macros, a ``size_t`` typedef and the float/int vector typedefs.
  When any uop has dtype bfloat16, a small self-contained ``hip_bfloat16``
  struct (float <-> 16-bit storage conversion plus ``<`` and ``==``) is
  emitted as well, instead of including ``amd_hip_bfloat16.h``.

  Args mirror the parent ``render_kernel``; the result is whatever the
  parent returns for the assembled prefix.
  """
  # NOTE: the incoming `prefix` argument is not consulted; the preamble is
  # rebuilt from scratch on every call.
  prefix = ["#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
            "typedef long unsigned int size_t;"]
  needs_bf16 = any(uop.dtype == dtypes.bfloat16 for uop in uops)
  if needs_bf16:
    # Constructor: for a non-all-ones exponent, add a rounding bias before
    # truncating to the top 16 bits; otherwise, if low bits are set, force a
    # bit that survives truncation. operator float() shifts the stored 16
    # bits back into the high half of a 32-bit word.
    prefix.append("""
struct hip_bfloat16 {
unsigned short data;
__attribute__((device)) hip_bfloat16(float val) {
union { float fp32; unsigned int u32; } u = {val};
if (~u.u32 & 0x7f800000) {
u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
} else if (u.u32 & 0xffff) {
u.u32 |= 0x10000;
}
data = (u.u32 >> 16);
}
__attribute__((device)) operator float() const {
unsigned int uval = data << 16;
return *reinterpret_cast<float*>(&uval);
}
};
static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) {
return ((float)a) < ((float)b);
}
static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
return ((float)a) == ((float)b);
}
""")
  # Vector typedefs are emitted unconditionally, bfloat16 kernel or not.
  vector_types = [("float", "float", 2), ("float", "float", 4),
                  ("signed int", "int", 4), ("signed int", "int", 2)]
  prefix.append('\n'.join(_make_hip_dtype(*x) for x in vector_types))
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)