From df18e7cc37c2b9d405687d1d3df540c5d718c810 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 5 Dec 2024 11:40:59 +0800
Subject: [PATCH] accept filename decorator [pr] (#8049)

* accept filename decorator [pr]

* add test for safe_load

* bring old tar tests back
---
 test/unit/test_disk_tensor.py | 14 ++++++--
 test/unit/test_tar.py         | 60 ++++++++++++++++++++++++++++++++++-
 tinygrad/nn/state.py          | 16 +++++++---
 3 files changed, 82 insertions(+), 8 deletions(-)

diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 7c1d239f16..14c6dbb62e 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -134,15 +134,23 @@ class TestSafetensors(unittest.TestCase):
       for k in f.keys():
         np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())

-  def test_huggingface_enet_safetensors(self):
-    # test a real file
-    fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+  def _test_huggingface_enet_safetensors(self, fn):
     state_dict = safe_load(fn)
     assert len(state_dict.keys()) == 244
     assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
     assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
     assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570

+  def test_huggingface_enet_safetensors(self):
+    # test a real file
+    fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+    self._test_huggingface_enet_safetensors(fn)
+
+  def test_huggingface_enet_safetensors_fromurl(self):
+    # test tensor input
+    t = Tensor.from_url("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+    self._test_huggingface_enet_safetensors(t)
+
   def test_metadata(self):
     metadata = {"hello": "world"}
     safe_save({}, temp('metadata.safetensors'), metadata)
diff --git a/test/unit/test_tar.py b/test/unit/test_tar.py
index ecff319425..42f00d1695 100644
--- a/test/unit/test_tar.py
+++ b/test/unit/test_tar.py
@@ -1,8 +1,66 @@
-import unittest, tarfile, io, os, pathlib
+import unittest, tarfile, io, os, pathlib, tempfile
 import numpy as np
 from tinygrad import Tensor
 from tinygrad.nn.state import tar_extract

+class TestTarExtractFile(unittest.TestCase):
+  def setUp(self):
+    self.test_dir = tempfile.mkdtemp()
+    self.test_files = {
+      'file1.txt': b'Hello, World!',
+      'file2.bin': b'\x00\x01\x02\x03\x04',
+      'empty_file.txt': b''
+    }
+    self.tar_path = os.path.join(self.test_dir, 'test.tar')
+    with tarfile.open(self.tar_path, 'w') as tar:
+      for filename, content in self.test_files.items():
+        file_path = os.path.join(self.test_dir, filename)
+        with open(file_path, 'wb') as f:
+          f.write(content)
+        tar.add(file_path, arcname=filename)
+
+    # Create invalid tar file
+    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
+    with open(self.invalid_tar_path, 'wb') as f:
+      f.write(b'This is not a valid tar file')
+
+  def tearDown(self):
+    for filename in self.test_files:
+      os.remove(os.path.join(self.test_dir, filename))
+    os.remove(self.tar_path)
+    os.remove(self.invalid_tar_path)
+    os.rmdir(self.test_dir)
+
+  def test_tar_extract_returns_dict(self):
+    result = tar_extract(self.tar_path)
+    self.assertIsInstance(result, dict)
+
+  def test_tar_extract_correct_keys(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(set(result.keys()), set(self.test_files.keys()))
+
+  def test_tar_extract_content_size(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      self.assertEqual(len(result[filename]), len(content))
+
+  def test_tar_extract_content_values(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
+
+  def test_tar_extract_empty_file(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(len(result['empty_file.txt']), 0)
+
+  def test_tar_extract_non_existent_file(self):
+    with self.assertRaises(FileNotFoundError):
+      tar_extract('non_existent_file.tar')
+
+  def test_tar_extract_invalid_file(self):
+    with self.assertRaises(tarfile.ReadError):
+      tar_extract(self.invalid_tar_path)
+
 class TestTarExtractPAX(unittest.TestCase):
   tar_format = tarfile.PAX_FORMAT
   max_link_len = 1000_000
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index dfd7d9d4be..ac9da2ddc6 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -1,5 +1,5 @@
 import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
-from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -35,16 +35,22 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

-def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+R = TypeVar('R')
+def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]:
+  @functools.wraps(func)
+  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
+  return wrapper
+
+@accept_filename
+def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
   assert isinstance(json_len, int)
   return t, json_len, json.loads(t[8:8+json_len].data().tobytes())

-def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   """
   Loads a .safetensor file from disk, returning the state_dict.

@@ -157,6 +163,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]

+@accept_filename
 def tar_extract(t: Tensor) -> Dict[str, Tensor]:
   """
   Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
@@ -287,6 +294,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
     return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")

+@accept_filename
 def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
   """
   Loads a gguf file from a tensor.
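
A minimal usage sketch, not part of the patch above, of what the accept_filename decorator enables: the decorated loaders (safe_load, safe_load_metadata, tar_extract, gguf_load) now take a str, pathlib.Path, or Tensor, since any non-Tensor argument is wrapped as Tensor(pathlib.Path(fn)). The file names "model.safetensors" and "weights.tar" below are hypothetical placeholders.

# sketch only; assumes the hypothetical local files "model.safetensors" and "weights.tar" exist
import pathlib
from tinygrad import Tensor
from tinygrad.nn.state import safe_load, tar_extract

# all three calls are equivalent entry points after this change
state_dict = safe_load("model.safetensors")                        # plain str path
state_dict = safe_load(pathlib.Path("model.safetensors"))          # pathlib.Path
state_dict = safe_load(Tensor(pathlib.Path("model.safetensors")))  # a Tensor passes through unchanged

# tar_extract (and gguf_load) gain the same flexibility via the decorator
files = tar_extract("weights.tar")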