tar_extract with Tensors (#7853)

* initial

* USTAR, PAX and GNU support + testing

* from_bytes byteorder

* use TarInfo.frombuf

* tensor only usage

* remove contextlib.suppress

* shorter ow,pax

* more tests

* testing length + move tests

* cleanup

* new approach: RawTensorIO

* fix fetch

* enable read test

* cleanup and ignore fix

* fix for python < 3.12

* make it RawIO

* functions

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Author: leopf
Date: 2024-12-04 10:03:19 +01:00
Committed by: GitHub
Parent: 1e06aefde7
Commit: f0401e14e8
5 changed files with 154 additions and 67 deletions
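For orientation, a minimal sketch of the API change this commit makes: tar_extract now takes a 1-D uint8 Tensor instead of a file path, so the archive bytes can come from disk, from memory, or from Tensor.from_url. The archive name below is a placeholder.

```python
import pathlib
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract

# before this commit: tensors = tar_extract("archive.tar")
# after: pass the archive as a uint8 Tensor ("archive.tar" is a placeholder path)
tensors = tar_extract(Tensor(pathlib.Path("archive.tar")))
for name, t in tensors.items(): print(name, t.shape, t.dtype)  # one uint8 Tensor per regular file
```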


@@ -1,8 +1,8 @@
-import os, pathlib, tempfile, unittest, tarfile
+import pathlib, tempfile, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
-from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
+from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp, CI
from tinygrad.device import is_dtype_supported
@@ -333,63 +333,6 @@ class TestDiskTensor(unittest.TestCase):
    on_dev = t.to(Device.DEFAULT).realize()
    np.testing.assert_equal(on_dev.numpy(), t.numpy())

-class TestTarExtract(unittest.TestCase):
-  def setUp(self):
-    self.test_dir = tempfile.mkdtemp()
-    self.test_files = {
-      'file1.txt': b'Hello, World!',
-      'file2.bin': b'\x00\x01\x02\x03\x04',
-      'empty_file.txt': b''
-    }
-    self.tar_path = os.path.join(self.test_dir, 'test.tar')
-    with tarfile.open(self.tar_path, 'w') as tar:
-      for filename, content in self.test_files.items():
-        file_path = os.path.join(self.test_dir, filename)
-        with open(file_path, 'wb') as f:
-          f.write(content)
-        tar.add(file_path, arcname=filename)
-    # Create invalid tar file
-    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
-    with open(self.invalid_tar_path, 'wb') as f:
-      f.write(b'This is not a valid tar file')
-  def tearDown(self):
-    for filename in self.test_files:
-      os.remove(os.path.join(self.test_dir, filename))
-    os.remove(self.tar_path)
-    os.remove(self.invalid_tar_path)
-    os.rmdir(self.test_dir)
-  def test_tar_extract_returns_dict(self):
-    result = tar_extract(self.tar_path)
-    self.assertIsInstance(result, dict)
-  def test_tar_extract_correct_keys(self):
-    result = tar_extract(self.tar_path)
-    self.assertEqual(set(result.keys()), set(self.test_files.keys()))
-  def test_tar_extract_content_size(self):
-    result = tar_extract(self.tar_path)
-    for filename, content in self.test_files.items():
-      self.assertEqual(len(result[filename]), len(content))
-  def test_tar_extract_content_values(self):
-    result = tar_extract(self.tar_path)
-    for filename, content in self.test_files.items():
-      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
-  def test_tar_extract_empty_file(self):
-    result = tar_extract(self.tar_path)
-    self.assertEqual(len(result['empty_file.txt']), 0)
-  def test_tar_extract_non_existent_file(self):
-    with self.assertRaises(FileNotFoundError):
-      tar_extract('non_existent_file.tar')
-  def test_tar_extract_invalid_file(self):
-    with self.assertRaises(tarfile.ReadError):
-      tar_extract(self.invalid_tar_path)

class TestPathTensor(unittest.TestCase):
  def setUp(self):

test/unit/test_tar.py (new file, 83 additions)

@@ -0,0 +1,83 @@
import unittest, tarfile, io, os, pathlib
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract

class TestTarExtractPAX(unittest.TestCase):
  tar_format = tarfile.PAX_FORMAT
  max_link_len = 1000_000
  test_files = {
    'a/file1.txt': b'Hello, World!',
    'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
    'empty_file.txt': b'',
    '512file': b'a' * 512,
    'long_file': b'some data' * 100,
    'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
    'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
  }

  def create_tar_tensor(self):
    fobj = io.BytesIO()
    test_dirs = set(os.path.dirname(k) for k in self.test_files.keys()).difference({ '' })
    with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
      for dirname in test_dirs:
        dir_info = tarfile.TarInfo(name=dirname)
        dir_info.type = tarfile.DIRTYPE
        tar.addfile(dir_info)
      for filename, content in self.test_files.items():
        file_info = tarfile.TarInfo(name=filename)
        file_info.size = len(content)
        tar.addfile(file_info, io.BytesIO(content))
        if len(filename) < self.max_link_len:
          link_info = tarfile.TarInfo(name=filename + '.lnk')
          link_info.type = tarfile.SYMTYPE
          link_info.linkname = filename
          tar.addfile(link_info)
    return Tensor(fobj.getvalue())

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'))

  def test_tar_extract_invalid_file_long(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'*100))

class TestTarExtractUSTAR(TestTarExtractPAX):
  tar_format = tarfile.USTAR_FORMAT
  max_link_len = 100
  test_files = {k: v for k, v in TestTarExtractPAX.test_files.items() if len(k) < 256}

class TestTarExtractGNU(TestTarExtractPAX):
  tar_format = tarfile.GNU_FORMAT

if __name__ == '__main__':
  unittest.main()


@@ -0,0 +1,38 @@
import unittest
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import TensorIO

class TestTensorIO(unittest.TestCase):
  def test_create(self):
    with self.assertRaises(ValueError):
      TensorIO(Tensor(b"Hello World").reshape(1, -1))
    with self.assertRaises(ValueError):
      TensorIO(Tensor([], dtype=dtypes.int64).reshape(1, -1))

  def test_seek(self):
    t = Tensor(b"Hello World!")
    fobj = TensorIO(t)
    self.assertEqual(fobj.tell(), 0)
    self.assertEqual(fobj.seek(1), 1)
    self.assertEqual(fobj.seek(-2, 2), len(t) - 2)
    self.assertEqual(fobj.seek(1, 1), len(t) - 1)
    self.assertEqual(fobj.seek(10, 1), len(t))
    self.assertEqual(fobj.seek(10, 2), len(t))
    self.assertEqual(fobj.seek(-10, 0), 0)

  def test_read(self):
    data = b"Hello World!"
    fobj = TensorIO(Tensor(data))
    self.assertEqual(fobj.read(1), data[:1])
    self.assertEqual(fobj.read(5), data[1:6])
    self.assertEqual(fobj.read(100), data[6:])
    self.assertEqual(fobj.read(100), b"")

  def test_read_nolen(self):
    data = b"Hello World!"
    fobj = TensorIO(Tensor(data))
    fobj.seek(2)
    self.assertEqual(fobj.read(), data[2:])

if __name__ == '__main__':
  unittest.main()
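Because TensorIO subclasses io.RawIOBase, it composes with the standard library's stream machinery wherever a read-only binary stream is expected. A small sketch of that; the JSON payload is only illustrative:

```python
import io, json
from tinygrad import Tensor
from tinygrad.nn.state import TensorIO

t = Tensor(json.dumps({"hello": "world"}).encode())  # any 1-D uint8 tensor works
buffered = io.BufferedReader(TensorIO(t))            # standard buffering on top of the raw stream
print(json.load(buffered))                           # {'hello': 'world'}
```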


@@ -1,5 +1,4 @@
from tinygrad.tensor import Tensor
-from tinygrad.helpers import fetch
from tinygrad.nn.state import tar_extract
def mnist(device=None, fashion=False):
@@ -9,7 +8,7 @@ def mnist(device=None, fashion=False):
         _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)

def cifar(device=None):
-  tt = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
+  tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
  train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
  test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
  return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]


@@ -1,11 +1,36 @@
-import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
-from typing import Dict, Union, List, Optional, Any, Tuple, Callable
+import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
from tinygrad.shape.view import strides_for_shape
from tinygrad.multi import MultiLazyBuffer
+class TensorIO(io.RawIOBase, BinaryIO):
+  def __init__(self, t: Tensor):
+    if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
+    self._position, self._tensor = 0, t
+  def readable(self) -> bool: return True
+  def read(self, size: int = -1) -> bytes:
+    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never)
+    return buf
+  def readinto(self, buffer: Any) -> int:
+    data = self._tensor[self._position:self._position+len(buffer)].data()
+    buffer[:len(data)] = data
+    self._position += len(data)
+    return len(data)
+  def seekable(self) -> bool: return True
+  def seek(self, offset: int, whence: int = 0) -> int:
+    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
+    return self._position
+
+  # required to correctly implement BinaryIO
+  def __enter__(self): return self
+  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
+  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
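The one-line seek above picks the absolute target by indexing a list with whence (0 = absolute, 1 = relative to the current position, 2 = relative to the end) and then clamps it into [0, len(tensor)]. A worked sketch with made-up numbers:

```python
# hypothetical values: a 12-byte tensor, cursor at byte 4, seek(-2, whence)
length, position, offset = 12, 4, -2
targets = [offset, position + offset, length + offset]  # same list TensorIO.seek indexes with whence
print(targets)                                          # [-2, 2, 10]
print([min(length, max(0, t)) for t in targets])        # [0, 2, 10] after clamping to the tensor bounds
```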
@@ -132,16 +157,15 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
    else: v.replace(state_dict[k].to(v.device)).realize()
    if consume: del state_dict[k]

-def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
+def tar_extract(t: Tensor) -> Dict[str, Tensor]:
  """
  Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).

  ```python
-  tensors = nn.state.tar_extract("archive.tar")
+  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
-  t = Tensor(pathlib.Path(fn))
-  with tarfile.open(fn, "r") as tar:
+  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
# torch support!
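Putting the pieces together: tar_extract wraps the input tensor in TensorIO so the standard tarfile module can parse the archive headers, then returns each regular member as a slice of the original tensor (offset_data to offset_data+size), so file contents stay lazy until realized. A hedged end-to-end sketch reusing the cifar archive and key from the datasets change above:

```python
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract

# gunzip=True because the cifar tarball is gzip-compressed
archive = Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True)
batches = tar_extract(archive)
first = batches["cifar-10-batches-bin/data_batch_1.bin"].reshape(-1, 3073)  # key taken from the cifar loader above
print(first.shape)  # expected (10000, 3073): 10000 records of 1 label byte + 3072 pixel bytes
```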