mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
accept filename decorator [pr] (#8049)
* accept filename decorator [pr]
* add test for safe_load
* bring old tar tests back
This commit is contained in:
@@ -134,15 +134,23 @@ class TestSafetensors(unittest.TestCase):
|
||||
for k in f.keys():
|
||||
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
||||
|
||||
def test_huggingface_enet_safetensors(self):
|
||||
# test a real file
|
||||
fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
|
||||
def _test_huggingface_enet_safetensors(self, fn):
  """
  Shared checks for the MobileNetV3 safetensors checkpoint.

  `fn` is forwarded unchanged to `safe_load`; the tests in this class pass either the result of `fetch(...)` or a
  Tensor built with `Tensor.from_url(...)`.
  """
  state_dict = safe_load(fn)
  # unittest assertions instead of bare `assert`: they are not stripped under `python -O` and they report both operands on failure
  self.assertEqual(len(state_dict.keys()), 244)
  self.assertIn('blocks.2.2.se.conv_reduce.weight', state_dict)
  self.assertEqual(state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy(), 276570)
  self.assertEqual(state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy(), 276570)
|
||||
|
||||
def test_huggingface_enet_safetensors(self):
  # real-file case: hand the value returned by fetch() straight to the shared checks
  url = "https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors"
  self._test_huggingface_enet_safetensors(fetch(url))
|
||||
|
||||
def test_huggingface_enet_safetensors_fromurl(self):
  # tensor-input case: build the Tensor from the URL and run the same shared checks on it
  url = "https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors"
  self._test_huggingface_enet_safetensors(Tensor.from_url(url))
|
||||
|
||||
def test_metadata(self):
|
||||
metadata = {"hello": "world"}
|
||||
safe_save({}, temp('metadata.safetensors'), metadata)
|
||||
|
||||
@@ -1,8 +1,66 @@
|
||||
import unittest, tarfile, io, os, pathlib
|
||||
import unittest, tarfile, io, os, pathlib, tempfile
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.nn.state import tar_extract
|
||||
|
||||
class TestTarExtractFile(unittest.TestCase):
  """
  Tests `tar_extract` against a small tar archive written to a temporary directory,
  plus a file that is not a valid tar archive and a path that does not exist.
  """
  def setUp(self):
    # TemporaryDirectory + addCleanup removes the whole tree after every test, including when setUp itself
    # fails part-way or a test leaves extra files behind (a manual per-file tearDown would leak in those cases)
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    self.test_dir = tmp_dir.name
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b''
    }
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        tar.add(file_path, arcname=filename)

    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      # compare the extracted tensor with the raw bytes that were written into the archive
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract('non_existent_file.tar')

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)
|
||||
|
||||
class TestTarExtractPAX(unittest.TestCase):
|
||||
tar_format = tarfile.PAX_FORMAT
|
||||
max_link_len = 1000_000
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
|
||||
@@ -35,16 +35,22 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
|
||||
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
|
||||
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
R = TypeVar('R')
|
||||
def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]:
  """
  Decorator that lets a function taking a Tensor also be called with a filename.

  A Tensor argument is passed to `func` untouched; any other argument is wrapped in `pathlib.Path`
  and turned into a Tensor before `func` is called.
  """
  @functools.wraps(func)
  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R:
    # already a Tensor: nothing to convert
    if isinstance(fn, Tensor): return func(fn)
    return func(Tensor(pathlib.Path(fn)))
  return wrapper
|
||||
|
||||
@accept_filename
def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
  """
  Loads a .safetensor file from disk, returning the data, metadata length, and metadata.

  The `accept_filename` decorator converts a filename argument into a Tensor before this body runs,
  so `t` is always a Tensor here and no filename handling is needed in the body.
  """
  # the first 8 bytes are read as an int64: the byte length of the JSON metadata that follows them
  json_len = t[0:8].bitcast(dtypes.int64).item()
  assert isinstance(json_len, int)
  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Loads a .safetensor file from disk, returning the state_dict.
|
||||
|
||||
@@ -157,6 +163,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
|
||||
else: v.replace(state_dict[k].to(v.device)).realize()
|
||||
if consume: del state_dict[k]
|
||||
|
||||
@accept_filename
|
||||
def tar_extract(t: Tensor) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
|
||||
@@ -287,6 +294,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
||||
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
|
||||
raise ValueError(f"GGML type '{ggml_type}' is not supported!")
|
||||
|
||||
@accept_filename
|
||||
def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
|
||||
"""
|
||||
Loads a gguf file from a tensor.
|
||||
|
||||
Reference in New Issue
Block a user