mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tar_extract with Tensors (#7853)
* initial * USTAR, PAX and GNU support + testing * from_bytes byteorder * use TarInfo.frombuf * tensor only usage * remove contextlib.suppress * shorter ow,pax * more tests * testing length + move tests * cleanup * new approach: RawTensorIO * fix fetch * enable read test * cleanup and ignore fix * fix for python < 3.12 * make it RawIO * functions --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import os, pathlib, tempfile, unittest, tarfile
|
||||
import pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, fetch, temp, CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -333,63 +333,6 @@ class TestDiskTensor(unittest.TestCase):
|
||||
on_dev = t.to(Device.DEFAULT).realize()
|
||||
np.testing.assert_equal(on_dev.numpy(), t.numpy())
|
||||
|
||||
class TestTarExtract(unittest.TestCase):
  """Exercises the path-based `tar_extract` against a small archive written into a temporary directory."""

  def setUp(self):
    # scratch directory that holds the payload files, the archive and a bogus archive
    self.test_dir = tempfile.mkdtemp()
    self.test_files = {'file1.txt': b'Hello, World!', 'file2.bin': b'\x00\x01\x02\x03\x04', 'empty_file.txt': b''}
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')

    # write every payload to disk first, then add it to the archive under its bare name
    with tarfile.open(self.tar_path, 'w') as archive:
      for name, payload in self.test_files.items():
        on_disk = os.path.join(self.test_dir, name)
        with open(on_disk, 'wb') as out: out.write(payload)
        archive.add(on_disk, arcname=name)

    # plain text posing as a tar file, used by the ReadError test
    with open(self.invalid_tar_path, 'wb') as out: out.write(b'This is not a valid tar file')

  def tearDown(self):
    # remove every file created in setUp before dropping the (then empty) directory
    created = [os.path.join(self.test_dir, name) for name in self.test_files]
    for path in created + [self.tar_path, self.invalid_tar_path]: os.remove(path)
    os.rmdir(self.test_dir)

  def test_tar_extract_returns_dict(self):
    self.assertIsInstance(tar_extract(self.tar_path), dict)

  def test_tar_extract_correct_keys(self):
    extracted = tar_extract(self.tar_path)
    self.assertEqual(set(extracted), set(self.test_files))

  def test_tar_extract_content_size(self):
    extracted = tar_extract(self.tar_path)
    for name, payload in self.test_files.items():
      self.assertEqual(len(extracted[name]), len(payload))

  def test_tar_extract_content_values(self):
    extracted = tar_extract(self.tar_path)
    for name, payload in self.test_files.items():
      expected = np.frombuffer(payload, dtype=np.uint8)
      np.testing.assert_array_equal(extracted[name].numpy(), expected)

  def test_tar_extract_empty_file(self):
    empty = tar_extract(self.tar_path)['empty_file.txt']
    self.assertEqual(len(empty), 0)

  def test_tar_extract_non_existent_file(self):
    self.assertRaises(FileNotFoundError, tar_extract, 'non_existent_file.tar')

  def test_tar_extract_invalid_file(self):
    self.assertRaises(tarfile.ReadError, tar_extract, self.invalid_tar_path)
|
||||
|
||||
class TestPathTensor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
83
test/unit/test_tar.py
Normal file
83
test/unit/test_tar.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import unittest, tarfile, io, os, pathlib
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.nn.state import tar_extract
|
||||
|
||||
class TestTarExtractPAX(unittest.TestCase):
  """
  Checks `tar_extract` on an archive built in memory with the PAX format.

  Subclasses override `tar_format`, `max_link_len` and `test_files` to rerun the same tests for other tar formats.
  """
  tar_format = tarfile.PAX_FORMAT
  # a symlink entry is only added for file names shorter than this
  # NOTE(review): presumably sized to the link-name limit of the format under test -- confirm
  max_link_len = 1000_000
  # archive member name -> expected content
  test_files = {
    'a/file1.txt': b'Hello, World!',
    'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
    'empty_file.txt': b'',
    '512file': b'a' * 512,
    'long_file': b'some data' * 100,
    'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
    'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
  }

  def create_tar_tensor(self):
    """Build the test archive (directories, files and symlinks) in memory and return its bytes wrapped in a Tensor."""
    fobj = io.BytesIO()
    # sorted instead of raw set iteration: set order of strings varies with hash randomization, sorting keeps
    # the archive layout reproducible between runs and writes parent directories before their children
    test_dirs = sorted({os.path.dirname(k) for k in self.test_files} - {''})
    with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
      for dirname in test_dirs:
        dir_info = tarfile.TarInfo(name=dirname)
        dir_info.type = tarfile.DIRTYPE
        tar.addfile(dir_info)

      for filename, content in self.test_files.items():
        file_info = tarfile.TarInfo(name=filename)
        file_info.size = len(content)
        tar.addfile(file_info, io.BytesIO(content))

        # symlink member pointing at the file just added; the key tests expect it to be absent from the result
        if len(filename) < self.max_link_len:
          link_info = tarfile.TarInfo(name=filename + '.lnk')
          link_info.type = tarfile.SYMTYPE
          link_info.linkname = filename
          tar.addfile(link_info)
    return Tensor(fobj.getvalue())

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    # only the regular files are expected as keys, not the directory or symlink members
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    # a Tensor created from a path that does not exist must surface FileNotFoundError
    with self.assertRaises(FileNotFoundError):
      tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))

  def test_tar_extract_invalid_file(self):
    # input shorter than one 512-byte tar block
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'))

  def test_tar_extract_invalid_file_long(self):
    # input longer than one 512-byte tar block (28 bytes * 100)
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'*100))
|
||||
|
||||
class TestTarExtractUSTAR(TestTarExtractPAX):
  """Reruns the PAX suite with USTAR archives, using a lower symlink threshold and no names of 256+ characters."""
  tar_format = tarfile.USTAR_FORMAT
  max_link_len = 100
  # reuse the parent's fixtures, minus the names that are 256 characters or longer
  test_files = {name: data for name, data in TestTarExtractPAX.test_files.items() if len(name) < 256}
|
||||
|
||||
class TestTarExtractGNU(TestTarExtractPAX):
  """Reruns the PAX suite unchanged (same files, same symlink threshold) with GNU-format archives."""
  tar_format = tarfile.GNU_FORMAT
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
38
test/unit/test_tensor_io.py
Normal file
38
test/unit/test_tensor_io.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.nn.state import TensorIO
|
||||
|
||||
class TestTensorIO(unittest.TestCase):
  """Covers construction checks, seeking and reading of `TensorIO`."""

  def test_create(self):
    # both inputs are 2-D after the reshape; the second one is additionally not uint8
    self.assertRaises(ValueError, lambda: TensorIO(Tensor(b"Hello World").reshape(1, -1)))
    self.assertRaises(ValueError, lambda: TensorIO(Tensor([], dtype=dtypes.int64).reshape(1, -1)))

  def test_seek(self):
    t = Tensor(b"Hello World!")
    fobj, size = TensorIO(t), len(t)
    self.assertEqual(fobj.tell(), 0)
    # the seeks are stateful, so they must run in exactly this order: (seek arguments, expected position)
    steps = [
      ((1,), 1),
      ((-2, 2), size - 2),
      ((1, 1), size - 1),
      ((10, 1), size),   # past the end -> clamped to the length
      ((10, 2), size),
      ((-10, 0), 0),     # before the start -> clamped to zero
    ]
    for seek_args, expected in steps:
      self.assertEqual(fobj.seek(*seek_args), expected)

  def test_read(self):
    data = b"Hello World!"
    fobj = TensorIO(Tensor(data))
    # consecutive reads: (requested size, expected bytes); the last one hits end-of-data
    for size, expected in [(1, data[:1]), (5, data[1:6]), (100, data[6:]), (100, b"")]:
      self.assertEqual(fobj.read(size), expected)

  def test_read_nolen(self):
    data = b"Hello World!"
    fobj = TensorIO(Tensor(data))
    fobj.seek(2)
    # read() without a size returns everything from the current position
    rest = fobj.read()
    self.assertEqual(rest, data[2:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,5 +1,4 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn.state import tar_extract
|
||||
|
||||
def mnist(device=None, fashion=False):
|
||||
@@ -9,7 +8,7 @@ def mnist(device=None, fashion=False):
|
||||
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
|
||||
|
||||
def cifar(device=None):
  """
  Download the CIFAR-10 binary archive and return `(train_images, train_labels, test_images, test_labels)`.

  Every batch file is viewed as rows of 3073 bytes: column 0 is returned as the label and the remaining
  3072 bytes are reshaped to `(3, 32, 32)`. All returned tensors are moved to `device`.
  """
  # the archive is loaded as a Tensor and handed to the tensor-based tar_extract
  tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
  # the training set is split over data_batch_1.bin .. data_batch_5.bin
  train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
  test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
  return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
|
||||
|
||||
@@ -1,11 +1,36 @@
|
||||
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple, Callable
|
||||
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
|
||||
class TensorIO(io.RawIOBase, BinaryIO):
  """
  Read-only, seekable binary file object over a 1-D uint8 Tensor.

  Lets APIs that expect a file object (e.g. `tarfile.open(fileobj=...)`) read straight from a Tensor.
  """
  def __init__(self, t: Tensor):
    """Wrap `t`; raises ValueError unless `t` is one-dimensional with dtype uint8."""
    if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._position, self._tensor = 0, t

  def readable(self) -> bool: return True
  def read(self, size: int = -1) -> bytes:
    """Read up to `size` bytes from the current position (everything that is left when `size` is -1)."""
    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never)
    return buf
  def readinto(self, buffer: Any) -> int:
    """Fill `buffer` from the current position, advance by the number of bytes copied and return that number."""
    # the slice is shorter than the buffer near the end of the tensor, so only len(data) bytes are written
    data = self._tensor[self._position:self._position+len(buffer)].data()
    buffer[:len(data)] = data
    self._position += len(data)
    return len(data)

  def seekable(self) -> bool: return True
  def seek(self, offset: int, whence: int = 0) -> int:
    """
    Move to `offset` relative to the start (0), the current position (1) or the end (2) and return the new position.

    The result is clamped to `[0, len(tensor)]` instead of failing on out-of-range targets.
    Raises ValueError for any other `whence`.
    """
    if whence == io.SEEK_SET: base = 0
    elif whence == io.SEEK_CUR: base = self._position
    elif whence == io.SEEK_END: base = len(self._tensor)
    # explicit check: list indexing would raise IndexError for 3 and silently accept negative values
    else: raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
    self._position = min(len(self._tensor), max(0, base + offset))
    return self._position

  # required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
|
||||
|
||||
safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
|
||||
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
|
||||
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
@@ -132,16 +157,15 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
|
||||
else: v.replace(state_dict[k].to(v.device)).realize()
|
||||
if consume: del state_dict[k]
|
||||
|
||||
def tar_extract(t: Tensor) -> Dict[str, Tensor]:
  """
  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  `t` holds the raw bytes of the archive and must be a 1-D uint8 tensor (enforced by `TensorIO`).
  Only regular-file members are returned; directories, links and other member types are skipped.
  Each value is a slice of `t` covering that member's data.

  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  # tarfile only parses the headers through TensorIO; the file contents are taken by slicing `t` directly
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
|
||||
|
||||
# torch support!
|
||||
|
||||
Reference in New Issue
Block a user