mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
tiny version of amd_hip_bfloat16 (#3868)
* add src_dtype * add maker * add bfloat16 * simpler
This commit is contained in:
@@ -259,6 +259,16 @@ code_for_op_hip = {
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
}
|
||||
|
||||
def _make_hip_code_for_op():
  """Build the HIP op -> renderer table with bfloat16-aware renderers.

  The table is the C-style language table overlaid with `code_for_op_hip`
  (HIP entries win on key collisions). Every renderer is wrapped: when the last
  positional argument equals `dtypes.bfloat16`, the operands are rendered as
  `(float)(...)`, the wrapped renderer is invoked with `dtypes.float`, and the
  resulting expression is wrapped in `(hip_bfloat16)(...)`. Any other trailing
  argument leaves the call untouched.
  """
  def wrapper(key, func):
    # For WHERE the leading argument is forwarded as-is; only the operands after it get the float cast.
    keep = 1 if key is TernaryOps.WHERE else 0

    def cast_bf16(*args):
      if args[-1] == dtypes.bfloat16:
        untouched, to_cast = args[:keep], args[keep:-1]
        as_float = tuple(f"(float)({src})" for src in to_cast)
        rendered = func(*untouched, *as_float, dtypes.float)
        return f"(hip_bfloat16)({rendered})"
      return func(*args)
    return cast_bf16

  merged = {**CStyleLanguage().code_for_op, **code_for_op_hip}
  return {op: wrapper(op, render) for op, render in merged.items()}
|
||||
|
||||
def _make_hip_dtype(base_type, name, cnt):
|
||||
nms = "xyzwabcdefghijkl"[:cnt]
|
||||
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
|
||||
@@ -299,7 +309,7 @@ class HIPLanguage(CStyleLanguage):
|
||||
}\nextern "C" __attribute__((global))"""
|
||||
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
|
||||
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_hip}
|
||||
code_for_op = _make_hip_code_for_op()
|
||||
smem_prefix = "__attribute__((shared))"
|
||||
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
|
||||
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
|
||||
@@ -310,8 +320,31 @@ class HIPLanguage(CStyleLanguage):
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
prefix = ["#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
|
||||
"typedef long unsigned int size_t;"]
|
||||
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <hip/amd_detail/amd_hip_bfloat16.h>")
|
||||
else: prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
|
||||
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
|
||||
struct hip_bfloat16 {
|
||||
unsigned short data;
|
||||
__attribute__((device)) hip_bfloat16(float val) {
|
||||
union { float fp32; unsigned int u32; } u = {val};
|
||||
if (~u.u32 & 0x7f800000) {
|
||||
u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
|
||||
} else if (u.u32 & 0xffff) {
|
||||
u.u32 |= 0x10000;
|
||||
}
|
||||
data = (u.u32 >> 16);
|
||||
}
|
||||
__attribute__((device)) operator float() const {
|
||||
unsigned int uval = data << 16;
|
||||
return *reinterpret_cast<float*>(&uval);
|
||||
}
|
||||
};
|
||||
static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
return ((float)a) < ((float)b);
|
||||
}
|
||||
static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
return ((float)a) == ((float)b);
|
||||
}
|
||||
""")
|
||||
prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
|
||||
("signed int", "int", 4), ("signed int", "int", 2)]))
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user