bfloat16 in LLVM (enough for llama 2) (#1293)

* add bf16 support to LLVM

* bf16 read works
This commit is contained in:
George Hotz
2023-07-19 20:18:32 -07:00
committed by GitHub
parent 74e63fe4ee
commit ca77d6cd72
6 changed files with 50 additions and 5 deletions

View File

@@ -3,7 +3,7 @@ import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor, dtypes
from extra.utils import OSX
from extra.utils import OSX, temp
def _test_to_np(a:Tensor, np_dtype, target):
print(a)
@@ -26,6 +26,39 @@ def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(l
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
def test_float_to_bf16(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
@unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
def test_bf16(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
t.realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
@unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
# hack to "cast" f32 -> bf16
dat = open(temp('f32'), "rb").read()
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")