diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 7d1b249f4a..7c1d239f16 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -1,8 +1,8 @@ -import os, pathlib, tempfile, unittest, tarfile +import pathlib, tempfile, unittest import numpy as np from tinygrad import Tensor, Device, dtypes from tinygrad.dtype import DType -from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract +from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load from tinygrad.helpers import Timing, fetch, temp, CI from tinygrad.device import is_dtype_supported @@ -333,63 +333,6 @@ class TestDiskTensor(unittest.TestCase): on_dev = t.to(Device.DEFAULT).realize() np.testing.assert_equal(on_dev.numpy(), t.numpy()) -class TestTarExtract(unittest.TestCase): - def setUp(self): - self.test_dir = tempfile.mkdtemp() - self.test_files = { - 'file1.txt': b'Hello, World!', - 'file2.bin': b'\x00\x01\x02\x03\x04', - 'empty_file.txt': b'' - } - self.tar_path = os.path.join(self.test_dir, 'test.tar') - with tarfile.open(self.tar_path, 'w') as tar: - for filename, content in self.test_files.items(): - file_path = os.path.join(self.test_dir, filename) - with open(file_path, 'wb') as f: - f.write(content) - tar.add(file_path, arcname=filename) - - # Create invalid tar file - self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar') - with open(self.invalid_tar_path, 'wb') as f: - f.write(b'This is not a valid tar file') - - def tearDown(self): - for filename in self.test_files: - os.remove(os.path.join(self.test_dir, filename)) - os.remove(self.tar_path) - os.remove(self.invalid_tar_path) - os.rmdir(self.test_dir) - - def test_tar_extract_returns_dict(self): - result = tar_extract(self.tar_path) - self.assertIsInstance(result, dict) - - def test_tar_extract_correct_keys(self): - result = tar_extract(self.tar_path) - self.assertEqual(set(result.keys()), set(self.test_files.keys())) - - def test_tar_extract_content_size(self): - result = tar_extract(self.tar_path) - for filename, content in self.test_files.items(): - self.assertEqual(len(result[filename]), len(content)) - - def test_tar_extract_content_values(self): - result = tar_extract(self.tar_path) - for filename, content in self.test_files.items(): - np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8)) - - def test_tar_extract_empty_file(self): - result = tar_extract(self.tar_path) - self.assertEqual(len(result['empty_file.txt']), 0) - - def test_tar_extract_non_existent_file(self): - with self.assertRaises(FileNotFoundError): - tar_extract('non_existent_file.tar') - - def test_tar_extract_invalid_file(self): - with self.assertRaises(tarfile.ReadError): - tar_extract(self.invalid_tar_path) class TestPathTensor(unittest.TestCase): def setUp(self): diff --git a/test/unit/test_tar.py b/test/unit/test_tar.py new file mode 100644 index 0000000000..ecff319425 --- /dev/null +++ b/test/unit/test_tar.py @@ -0,0 +1,83 @@ +import unittest, tarfile, io, os, pathlib +import numpy as np +from tinygrad import Tensor +from tinygrad.nn.state import tar_extract + +class TestTarExtractPAX(unittest.TestCase): + tar_format = tarfile.PAX_FORMAT + max_link_len = 1000_000 + test_files = { + 'a/file1.txt': b'Hello, World!', + 'a/b/file2.bin': b'\x00\x01\x02\x03\x04', + 'empty_file.txt': b'', + '512file': b'a' * 512, + 'long_file': b'some data' * 100, + 'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!', + 'very' * 200 + '_long_filename.txt': b'Hello, World!!!', + } + + def create_tar_tensor(self): + fobj = io.BytesIO() + test_dirs = set(os.path.dirname(k) for k in self.test_files.keys()).difference({ '' }) + with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar: + for dirname in test_dirs: + dir_info = tarfile.TarInfo(name=dirname) + dir_info.type = tarfile.DIRTYPE + tar.addfile(dir_info) + + for filename, content in self.test_files.items(): + file_info = tarfile.TarInfo(name=filename) + file_info.size = len(content) + tar.addfile(file_info, io.BytesIO(content)) + + if len(filename) < self.max_link_len: + link_info = tarfile.TarInfo(name=filename + '.lnk') + link_info.type = tarfile.SYMTYPE + link_info.linkname = filename + tar.addfile(link_info) + return Tensor(fobj.getvalue()) + + def test_tar_extract_returns_dict(self): + result = tar_extract(self.create_tar_tensor()) + self.assertIsInstance(result, dict) + + def test_tar_extract_correct_keys(self): + result = tar_extract(self.create_tar_tensor()) + self.assertEqual(set(result.keys()), set(self.test_files.keys())) + + def test_tar_extract_content_size(self): + result = tar_extract(self.create_tar_tensor()) + for filename, content in self.test_files.items(): + self.assertEqual(len(result[filename]), len(content)) + + def test_tar_extract_content_values(self): + result = tar_extract(self.create_tar_tensor()) + for filename, content in self.test_files.items(): + np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8)) + + def test_tar_extract_empty_file(self): + result = tar_extract(self.create_tar_tensor()) + self.assertEqual(len(result['empty_file.txt']), 0) + + def test_tar_extract_non_existent_file(self): + with self.assertRaises(FileNotFoundError): + tar_extract(Tensor(pathlib.Path('non_existent_file.tar'))) + + def test_tar_extract_invalid_file(self): + with self.assertRaises(tarfile.ReadError): + tar_extract(Tensor(b'This is not a valid tar file')) + + def test_tar_extract_invalid_file_long(self): + with self.assertRaises(tarfile.ReadError): + tar_extract(Tensor(b'This is not a valid tar file'*100)) + +class TestTarExtractUSTAR(TestTarExtractPAX): + tar_format = tarfile.USTAR_FORMAT + max_link_len = 100 + test_files = {k: v for k, v in TestTarExtractPAX.test_files.items() if len(k) < 256} + +class TestTarExtractGNU(TestTarExtractPAX): + tar_format = tarfile.GNU_FORMAT + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/test_tensor_io.py b/test/unit/test_tensor_io.py new file mode 100644 index 0000000000..a839257289 --- /dev/null +++ b/test/unit/test_tensor_io.py @@ -0,0 +1,38 @@ +import unittest +from tinygrad import Tensor, dtypes +from tinygrad.nn.state import TensorIO + +class TestTensorIO(unittest.TestCase): + def test_create(self): + with self.assertRaises(ValueError): + TensorIO(Tensor(b"Hello World").reshape(1, -1)) + with self.assertRaises(ValueError): + TensorIO(Tensor([], dtype=dtypes.int64).reshape(1, -1)) + + def test_seek(self): + t = Tensor(b"Hello World!") + fobj = TensorIO(t) + self.assertEqual(fobj.tell(), 0) + self.assertEqual(fobj.seek(1), 1) + self.assertEqual(fobj.seek(-2, 2), len(t) - 2) + self.assertEqual(fobj.seek(1, 1), len(t) - 1) + self.assertEqual(fobj.seek(10, 1), len(t)) + self.assertEqual(fobj.seek(10, 2), len(t)) + self.assertEqual(fobj.seek(-10, 0), 0) + + def test_read(self): + data = b"Hello World!" + fobj = TensorIO(Tensor(data)) + self.assertEqual(fobj.read(1), data[:1]) + self.assertEqual(fobj.read(5), data[1:6]) + self.assertEqual(fobj.read(100), data[6:]) + self.assertEqual(fobj.read(100), b"") + + def test_read_nolen(self): + data = b"Hello World!" + fobj = TensorIO(Tensor(data)) + fobj.seek(2) + self.assertEqual(fobj.read(), data[2:]) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/nn/datasets.py b/tinygrad/nn/datasets.py index 31d560373a..30439fb49b 100644 --- a/tinygrad/nn/datasets.py +++ b/tinygrad/nn/datasets.py @@ -1,5 +1,4 @@ from tinygrad.tensor import Tensor -from tinygrad.helpers import fetch from tinygrad.nn.state import tar_extract def mnist(device=None, fashion=False): @@ -9,7 +8,7 @@ def mnist(device=None, fashion=False): _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device) def cifar(device=None): - tt = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True)) + tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True)) train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)]) test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device) return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0] diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 83d54bdfcd..e29df2dfd5 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -1,11 +1,36 @@ -import os, json, pathlib, zipfile, pickle, tarfile, struct, functools -from typing import Dict, Union, List, Optional, Any, Tuple, Callable +import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io +from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm from tinygrad.shape.view import strides_for_shape from tinygrad.multi import MultiLazyBuffer +class TensorIO(io.RawIOBase, BinaryIO): + def __init__(self, t: Tensor): + if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!") + self._position, self._tensor = 0, t + + def readable(self) -> bool: return True + def read(self, size: int = -1) -> bytes: + if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never) + return buf + def readinto(self, buffer: Any) -> int: + data = self._tensor[self._position:self._position+len(buffer)].data() + buffer[:len(data)] = data + self._position += len(data) + return len(data) + + def seekable(self) -> bool: return True + def seek(self, offset: int, whence: int = 0) -> int: + self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence])) + return self._position + + # required to correctly implement BinaryIO + def __enter__(self): return self + def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported") + def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported") + safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint, "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64} inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} @@ -132,16 +157,15 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr else: v.replace(state_dict[k].to(v.device)).realize() if consume: del state_dict[k] -def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]: +def tar_extract(t: Tensor) -> Dict[str, Tensor]: """ Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values). ```python - tensors = nn.state.tar_extract("archive.tar") + tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar"))) ``` """ - t = Tensor(pathlib.Path(fn)) - with tarfile.open(fn, "r") as tar: + with tarfile.open(fileobj=TensorIO(t), mode="r") as tar: return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE} # torch support!