changes for teenygrad (#3665)

* changes for teenygrad

* upd

* simpler test
This commit is contained in:
George Hotz
2024-03-09 15:30:34 -08:00
committed by GitHub
parent 89b8b5d549
commit 69ca7f7bf9
5 changed files with 38 additions and 38 deletions

View File

@@ -2,7 +2,7 @@ import pathlib, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, CI, fetch, temp
from tinygrad.helpers import Timing, CI, fetch, temp, getenv
def compare_weights_both(url):
import torch
@@ -216,5 +216,19 @@ class TestDiskTensor(unittest.TestCase):
np.testing.assert_array_equal(t.numpy(), np.array([3] * 10))
@unittest.skipIf(getenv("HIPCPU"), "no real HIP device exists in CI")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
# hack to "cast" f32 -> bf16
with open(temp('f32'), "rb") as f: dat = f.read()
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
if __name__ == "__main__":
unittest.main()