accept filename decorator [pr] (#8049)

* accept filename decorator [pr]

* add test for safe_load

* bring old tar tests back
George Hotz
2024-12-05 11:40:59 +08:00
committed by GitHub
parent c3187087f7
commit df18e7cc37
3 changed files with 82 additions and 8 deletions


@@ -134,15 +134,23 @@ class TestSafetensors(unittest.TestCase):
       for k in f.keys():
         np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
 
-  def test_huggingface_enet_safetensors(self):
-    # test a real file
-    fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+  def _test_huggingface_enet_safetensors(self, fn):
     state_dict = safe_load(fn)
     assert len(state_dict.keys()) == 244
     assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
     assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
     assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
 
+  def test_huggingface_enet_safetensors(self):
+    # test a real file
+    fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+    self._test_huggingface_enet_safetensors(fn)
+
+  def test_huggingface_enet_safetensors_fromurl(self):
+    # test tensor input
+    t = Tensor.from_url("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+    self._test_huggingface_enet_safetensors(t)
+
   def test_metadata(self):
     metadata = {"hello": "world"}
     safe_save({}, temp('metadata.safetensors'), metadata)
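
For readers skimming the diff: a minimal, hedged sketch of what the two new tests above exercise, using only names visible in the hunk (safe_load, fetch, Tensor.from_url); running it downloads the real checkpoint.

from tinygrad import Tensor
from tinygrad.helpers import fetch
from tinygrad.nn.state import safe_load

url = "https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors"
sd_from_path = safe_load(fetch(url))              # local path, as in test_huggingface_enet_safetensors
sd_from_tensor = safe_load(Tensor.from_url(url))  # Tensor input, as in test_huggingface_enet_safetensors_fromurl
assert len(sd_from_path) == len(sd_from_tensor) == 244  # key count asserted by the tests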


@@ -1,8 +1,66 @@
-import unittest, tarfile, io, os, pathlib
+import unittest, tarfile, io, os, pathlib, tempfile
 import numpy as np
 from tinygrad import Tensor
 from tinygrad.nn.state import tar_extract
 
+class TestTarExtractFile(unittest.TestCase):
+  def setUp(self):
+    self.test_dir = tempfile.mkdtemp()
+    self.test_files = {
+      'file1.txt': b'Hello, World!',
+      'file2.bin': b'\x00\x01\x02\x03\x04',
+      'empty_file.txt': b''
+    }
+    self.tar_path = os.path.join(self.test_dir, 'test.tar')
+    with tarfile.open(self.tar_path, 'w') as tar:
+      for filename, content in self.test_files.items():
+        file_path = os.path.join(self.test_dir, filename)
+        with open(file_path, 'wb') as f:
+          f.write(content)
+        tar.add(file_path, arcname=filename)
+    # Create invalid tar file
+    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
+    with open(self.invalid_tar_path, 'wb') as f:
+      f.write(b'This is not a valid tar file')
+
+  def tearDown(self):
+    for filename in self.test_files:
+      os.remove(os.path.join(self.test_dir, filename))
+    os.remove(self.tar_path)
+    os.remove(self.invalid_tar_path)
+    os.rmdir(self.test_dir)
+
+  def test_tar_extract_returns_dict(self):
+    result = tar_extract(self.tar_path)
+    self.assertIsInstance(result, dict)
+
+  def test_tar_extract_correct_keys(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(set(result.keys()), set(self.test_files.keys()))
+
+  def test_tar_extract_content_size(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      self.assertEqual(len(result[filename]), len(content))
+
+  def test_tar_extract_content_values(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
+
+  def test_tar_extract_empty_file(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(len(result['empty_file.txt']), 0)
+
+  def test_tar_extract_non_existent_file(self):
+    with self.assertRaises(FileNotFoundError):
+      tar_extract('non_existent_file.tar')
+
+  def test_tar_extract_invalid_file(self):
+    with self.assertRaises(tarfile.ReadError):
+      tar_extract(self.invalid_tar_path)
+
 class TestTarExtractPAX(unittest.TestCase):
   tar_format = tarfile.PAX_FORMAT
   max_link_len = 1000_000
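
As orientation for the new TestTarExtractFile class above: tar_extract now accepts a plain path (via the @accept_filename decorator in nn/state.py) and returns a dict mapping member names to uint8 Tensors. A hedged, self-contained sketch; the temporary-file handling here is illustrative, not taken from the diff.

import io, os, tarfile, tempfile
from tinygrad.nn.state import tar_extract

with tempfile.TemporaryDirectory() as d:
  tar_path = os.path.join(d, "example.tar")  # illustrative name
  data = b"Hello, World!"
  with tarfile.open(tar_path, "w") as tar:
    info = tarfile.TarInfo("file1.txt")
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))
  members = tar_extract(tar_path)            # a str path is accepted, as the tests check
  assert members["file1.txt"].numpy().tobytes() == data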


@@ -1,5 +1,5 @@
 import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
-from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -35,16 +35,22 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 
-def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+R = TypeVar('R')
+def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]:
+  @functools.wraps(func)
+  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
+  return wrapper
+
+@accept_filename
+def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
   assert isinstance(json_len, int)
   return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
 
-def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   """
   Loads a .safetensor file from disk, returning the state_dict.
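
The hunk above is the heart of the PR: accept_filename wraps a Tensor-consuming loader so that str and pathlib.Path arguments are converted to a disk Tensor in one place, which is why safe_load_metadata can drop its own fn-vs-Tensor branch. A hedged usage sketch (the file name is a placeholder, not from the diff):

import pathlib
from tinygrad import Tensor
from tinygrad.nn.state import safe_load_metadata

fn = "model.safetensors"                             # placeholder path
data, json_len, metadata = safe_load_metadata(fn)    # str is wrapped into Tensor(pathlib.Path(fn)) by the decorator
also = safe_load_metadata(Tensor(pathlib.Path(fn)))  # passing a Tensor directly behaves as before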
@@ -157,6 +163,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]
 
+@accept_filename
 def tar_extract(t: Tensor) -> Dict[str, Tensor]:
   """
   Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
@@ -287,6 +294,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
     return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")
 
+@accept_filename
 def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
   """
   Loads a gguf file from a tensor.
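
With @accept_filename applied to gguf_load as well, a downloaded file path can be passed straight through; a hedged sketch with a placeholder URL:

from tinygrad.helpers import fetch
from tinygrad.nn.state import gguf_load

gguf_url = "https://example.com/model.gguf"       # placeholder URL, not from the diff
kv_data, state_dict = gguf_load(fetch(gguf_url))  # returns (metadata dict, name -> Tensor dict)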