hotfix: hip bfloat formatting

This commit is contained in:
George Hotz
2024-03-22 11:52:05 -07:00
parent 54dc48aa47
commit 0c197b9cf3

View File

@@ -325,24 +325,16 @@ struct hip_bfloat16 {
unsigned short data;
__attribute__((device)) hip_bfloat16(float val) {
union { float fp32; unsigned int u32; } u = {val};
if (~u.u32 & 0x7f800000) {
u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
} else if (u.u32 & 0xffff) {
u.u32 |= 0x10000;
}
if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
data = (u.u32 >> 16);
}
__attribute__((device)) operator float() const {
unsigned int uval = data << 16;
return *reinterpret_cast<float*>(&uval);
unsigned int uval = data << 16;
return *reinterpret_cast<float*>(&uval);
}
};
static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) {
return ((float)a) < ((float)b);
}
static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
return ((float)a) == ((float)b);
}
static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")
prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
("signed int", "int", 4), ("signed int", "int", 2)]))