Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-07 03:00:26 -04:00.
minor cleanup of test_disk_tensor (#3112)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import pathlib
|
||||
import unittest
|
||||
import pathlib, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
@@ -33,7 +32,7 @@ class TestTorchLoad(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
|
||||
# Regression test for loading a bfloat16 checkpoint: downloads a lightweight
# consolidated.00.pth from HuggingFace and runs it through compare_weights_both.
# NOTE(review): compare_weights_both is defined elsewhere in this file —
# presumably it loads the URL via both weight-loading paths and asserts the
# tensors match; confirm against its definition. Skipped on TORCH backend
# because torch can't represent the uint32-cast bfloat load (see decorator).
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
|
||||
|
||||
# TODO: support pytorch tar format with minimal lines
|
||||
# pytorch tar format
|
||||
# Loads a ResNet-50 checkpoint stored in the legacy PyTorch tar serialization
# format (per the surrounding "pytorch tar format" comment) and compares the
# weights via compare_weights_both.
# NOTE(review): compare_weights_both is defined elsewhere in this file; this
# test performs a network download — verify it is gated appropriately.
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
|
||||
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
|
||||
@@ -80,7 +79,7 @@ class TestSafetensors(unittest.TestCase):
|
||||
assert os.path.getsize(fn) == 8+0x40+(10*10*4)
|
||||
from safetensors import safe_open
|
||||
with safe_open(fn, framework="pt", device="cpu") as f:
|
||||
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
|
||||
assert sorted(f.keys()) == sorted(state_dict.keys())
|
||||
for k in f.keys():
|
||||
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
||||
|
||||
@@ -90,14 +89,14 @@ class TestSafetensors(unittest.TestCase):
|
||||
state_dict = get_state_dict(model)
|
||||
safe_save(state_dict, temp("eff0"))
|
||||
state_dict_loaded = safe_load(temp("eff0"))
|
||||
assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
|
||||
assert sorted(state_dict_loaded.keys()) == sorted(state_dict.keys())
|
||||
for k,v in state_dict.items():
|
||||
np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
|
||||
|
||||
# load with the real safetensors
|
||||
from safetensors import safe_open
|
||||
with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
|
||||
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
|
||||
assert sorted(f.keys()) == sorted(state_dict.keys())
|
||||
for k in f.keys():
|
||||
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
||||
|
||||
@@ -151,14 +150,14 @@ class TestDiskTensor(unittest.TestCase):
|
||||
outdisk.realize()
|
||||
del out, outdisk
|
||||
|
||||
import struct
|
||||
# test file
|
||||
with open(temp("dt2"), "rb") as f:
|
||||
assert f.read() == b"\x00\x00\x80\x3F" * 100
|
||||
assert f.read() == struct.pack('<f', 1.0) * 100 == b"\x00\x00\x80\x3F" * 100
|
||||
|
||||
# test load alt
|
||||
reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}")
|
||||
out = reloaded.numpy()
|
||||
assert np.all(out == 1.)
|
||||
np.testing.assert_almost_equal(reloaded.numpy(), np.ones((10, 10)))
|
||||
|
||||
def test_assign_slice(self):
|
||||
def assign(x, s, y):
    """Perform the slice/index assignment ``x[s] = y``.

    Wrapping the assignment statement in a function lets the enclosing
    test pass it around as a callable.
    """
    x[s] = y
|
||||
|
||||
@@ -13,8 +13,8 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
|
||||
return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
|
||||
json_len = t[0:1].cast(dtypes.int64).item()
|
||||
return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
t, json_len, metadata = safe_load_metadata(fn)
|
||||
|
||||
Reference in New Issue
Block a user