bfloat16 in LLVM (enough for llama 2) (#1293)
* add bf16 support to LLVM
* bf16 read works
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.helpers import getenv, DType, DEBUG
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor, dtypes
-from extra.utils import OSX
+from extra.utils import OSX, temp
 
 def _test_to_np(a:Tensor, np_dtype, target):
   print(a)
@@ -26,6 +26,39 @@ def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(l
 def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
 def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
 
+
+class TestBFloat16DType(unittest.TestCase):
+  def test_bf16_to_float(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
+
+  def test_float_to_bf16(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
+
+  # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
+    t.realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16_disk_write_read(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
+    t.to(f"disk:{temp('f32')}").realize()
+
+    # hack to "cast" f32 -> bf16
+    dat = open(temp('f32'), "rb").read()
+    adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
+    with open(temp('bf16'), "wb") as f: f.write(adat)
+
+    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
 # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
 # for LLVM, it segfaults because it can't link to the casting function
 @unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
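Note on the expected values: bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 stored mantissa bits, so 10000 (which needs 9 mantissa bits) is not representable and drops to the nearest lower bf16 value, 9984; -1, -1000 and 20 fit exactly. A minimal standalone sketch of that round trip (plain numpy, not tinygrad code; helper names are made up for illustration), using the fact that a bf16 bit pattern is just the top 16 bits of a float32:

import numpy as np

# hypothetical helpers, not part of tinygrad: truncate f32 -> bf16 bits and back
def f32_to_bf16_bits(x):
  # keep the high 16 bits (sign, exponent, top 7 mantissa bits)
  return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_f32(bits):
  # widen back by zero-filling the 16 dropped mantissa bits
  return (bits.astype(np.uint32) << 16).view(np.float32)

vals = np.array([10000, -1, -1000, -10000, 20], dtype=np.float32)
print(bf16_bits_to_f32(f32_to_bf16_bits(vals)).tolist())
# -> [9984.0, -1.0, -1000.0, -9984.0, 20.0], matching the asserts above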
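The disk test's "hack to cast" relies on the same layout: on a little-endian host each float32 is written low byte first, so the two high-order bytes of element i are dat[i+2:i+4], which is exactly that element's bf16 bit pattern. A struct-based sketch of the same byte trick (standalone, hypothetical helper names, assumes little-endian packing as the test does):

import struct

def to_bf16_bytes(x):
  # pack as little-endian float32 and keep only the two high-order bytes
  return struct.pack("<f", x)[2:4]

def from_bf16_bytes(b):
  # zero-pad the two dropped low-order mantissa bytes to rebuild a float32
  return struct.unpack("<f", b"\x00\x00" + b)[0]

print([from_bf16_bytes(to_bf16_bytes(v)) for v in [10000, -1, -1000, -10000, 20]])
# -> [9984.0, -1.0, -1000.0, -9984.0, 20.0]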