minor cleanup of test_disk_tensor (#3112)

This commit is contained in:
chenyu
2024-01-13 20:54:58 -05:00
committed by GitHub
parent 9c73d2724f
commit c658aa4fbf
2 changed files with 10 additions and 11 deletions

View File

@@ -1,5 +1,4 @@
import pathlib
import unittest
import pathlib, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
@@ -33,7 +32,7 @@ class TestTorchLoad(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
# TODO: support pytorch tar format with minimal lines
# pytorch tar format
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
@@ -80,7 +79,7 @@ class TestSafetensors(unittest.TestCase):
assert os.path.getsize(fn) == 8+0x40+(10*10*4)
from safetensors import safe_open
with safe_open(fn, framework="pt", device="cpu") as f:
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
assert sorted(f.keys()) == sorted(state_dict.keys())
for k in f.keys():
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
@@ -90,14 +89,14 @@ class TestSafetensors(unittest.TestCase):
state_dict = get_state_dict(model)
safe_save(state_dict, temp("eff0"))
state_dict_loaded = safe_load(temp("eff0"))
assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
assert sorted(state_dict_loaded.keys()) == sorted(state_dict.keys())
for k,v in state_dict.items():
np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
# load with the real safetensors
from safetensors import safe_open
with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
assert sorted(f.keys()) == sorted(state_dict.keys())
for k in f.keys():
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
@@ -151,14 +150,14 @@ class TestDiskTensor(unittest.TestCase):
outdisk.realize()
del out, outdisk
import struct
# test file
with open(temp("dt2"), "rb") as f:
assert f.read() == b"\x00\x00\x80\x3F" * 100
assert f.read() == struct.pack('<f', 1.0) * 100 == b"\x00\x00\x80\x3F" * 100
# test load alt
reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}")
out = reloaded.numpy()
assert np.all(out == 1.)
np.testing.assert_almost_equal(reloaded.numpy(), np.ones((10, 10)))
def test_assign_slice(self):
def assign(x,s,y): x[s] = y

View File

@@ -13,8 +13,8 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  """Parse the header of a safetensors file.

  Accepts either an already-constructed Tensor over the file bytes or a path,
  which is memory-mapped as a uint8 disk tensor.

  Returns (t, json_len, metadata): the backing tensor, the byte length of the
  JSON header, and the decoded header dict.
  """
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  # safetensors layout: first 8 bytes are a little-endian u64 giving the JSON header length.
  # NOTE(review): t[0:1] is one uint8 element cast to int64 — presumably the disk cast
  # reinterprets the underlying 8 bytes rather than converting the single value; confirm.
  json_len = t[0:1].cast(dtypes.int64).item()
  # the JSON header itself occupies bytes [8, 8+json_len)
  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
t, json_len, metadata = safe_load_metadata(fn)