tar_extract with Tensors (#7853)

* initial

* USTAR, PAX and GNU support + testing

* from_bytes byteorder

* use TarInfo.frombuf

* tensor only usage

* remove contextlib.suppress

* shorter ow,pax

* more tests

* testing length + move tests

* cleanup

* new approach: RawTensorIO

* fix fetch

* enable read test

* cleanup and ignore fix

* fix for python < 3.12

* make it RawIO

* functions

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
leopf
2024-12-04 10:03:19 +01:00
committed by GitHub
parent 1e06aefde7
commit f0401e14e8
5 changed files with 154 additions and 67 deletions

View File

@@ -1,8 +1,8 @@
import os, pathlib, tempfile, unittest, tarfile
import pathlib, tempfile, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp, CI
from tinygrad.device import is_dtype_supported
@@ -333,63 +333,6 @@ class TestDiskTensor(unittest.TestCase):
on_dev = t.to(Device.DEFAULT).realize()
np.testing.assert_equal(on_dev.numpy(), t.numpy())
class TestTarExtract(unittest.TestCase):
  """Exercise `tar_extract` on tar archives written to a temporary directory."""

  def setUp(self):
    """Create a valid tar archive with three members and a file that is not a tar archive."""
    # TemporaryDirectory + addCleanup removes the whole tree even if setUp fails part-way
    # or a test leaves extra files behind; a manual per-file tearDown cannot guarantee that.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    self.test_dir = tmp_dir.name
    # archive member name -> member content
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b'',
    }
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        tar.add(file_path, arcname=filename)
    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract('non_existent_file.tar')

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)
class TestPathTensor(unittest.TestCase):
def setUp(self):

83
test/unit/test_tar.py Normal file
View File

@@ -0,0 +1,83 @@
import unittest, tarfile, io, os, pathlib
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract
class TestTarExtractPAX(unittest.TestCase):
  """Check `tar_extract` on an in-memory archive written in PAX format.

  The archive layout is driven by the class attributes `tar_format`, `max_link_len` and `test_files`.
  """
  tar_format = tarfile.PAX_FORMAT
  # a symlink entry is only added for members whose name is shorter than this
  max_link_len = 1000_000
  # archive member name -> member content
  test_files = {
    'a/file1.txt': b'Hello, World!',
    'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
    'empty_file.txt': b'',
    '512file': b'a' * 512,
    'long_file': b'some data' * 100,
    'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
    'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
  }

  def create_tar_tensor(self):
    """Build the test archive in memory and return its raw bytes wrapped in a Tensor.

    The archive contains a directory entry for the parent directory of every member,
    the regular files from `test_files`, and a `<name>.lnk` symlink to each file whose
    name is shorter than `max_link_len`.
    """
    fobj = io.BytesIO()
    # sorted() instead of raw set iteration: set order of strings depends on hash randomization,
    # sorting makes the archive reproducible across runs and puts a parent before its children
    test_dirs = sorted({os.path.dirname(k) for k in self.test_files} - {''})
    with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
      for dirname in test_dirs:
        dir_info = tarfile.TarInfo(name=dirname)
        dir_info.type = tarfile.DIRTYPE
        tar.addfile(dir_info)
      for filename, content in self.test_files.items():
        file_info = tarfile.TarInfo(name=filename)
        file_info.size = len(content)
        tar.addfile(file_info, io.BytesIO(content))
        if len(filename) < self.max_link_len:
          link_info = tarfile.TarInfo(name=filename + '.lnk')
          link_info.type = tarfile.SYMTYPE
          link_info.linkname = filename
          tar.addfile(link_info)
    return Tensor(fobj.getvalue())

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    # only the regular files are expected as keys, not the directory or symlink entries
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'))

  def test_tar_extract_invalid_file_long(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'*100))
class TestTarExtractUSTAR(TestTarExtractPAX):
  """Re-run the inherited tar_extract tests on an archive written in USTAR format."""
  tar_format = tarfile.USTAR_FORMAT
  # symlinks are only generated for member names shorter than 100 characters
  # NOTE(review): presumably mirrors the USTAR link-name field limit -- confirm
  max_link_len = 100
  # keep only the inherited members whose name is shorter than 256 characters
  # NOTE(review): presumably longer names cannot be stored in USTAR -- confirm
  test_files = {
    name: data
    for name, data in TestTarExtractPAX.test_files.items()
    if len(name) < 256
  }
class TestTarExtractGNU(TestTarExtractPAX):
  """Re-run the inherited tar_extract tests on an archive written in GNU tar format."""
  tar_format = tarfile.GNU_FORMAT
# run this module's test cases when the file is executed directly as a script
if __name__ == '__main__':
  unittest.main()

View File

@@ -0,0 +1,38 @@
import unittest
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import TensorIO
class TestTensorIO(unittest.TestCase):
  """Tests for `TensorIO`, the file-like reader over a Tensor of bytes."""

  def test_create(self):
    """Wrapping a tensor that was reshaped to 2D must raise ValueError."""
    self.assertRaises(ValueError, lambda: TensorIO(Tensor(b"Hello World").reshape(1, -1)))
    self.assertRaises(ValueError, lambda: TensorIO(Tensor([], dtype=dtypes.int64).reshape(1, -1)))

  def test_seek(self):
    """seek() returns the new position; the expectations show it staying within [0, len]."""
    tensor = Tensor(b"Hello World!")
    stream = TensorIO(tensor)
    size = len(tensor)
    self.assertEqual(stream.tell(), 0)
    # (seek arguments, expected position); the steps are stateful, so order matters
    steps = (
      ((1,), 1),
      ((-2, 2), size - 2),
      ((1, 1), size - 1),
      ((10, 1), size),
      ((10, 2), size),
      ((-10, 0), 0),
    )
    for seek_args, expected_pos in steps:
      self.assertEqual(stream.seek(*seek_args), expected_pos)

  def test_read(self):
    """Consecutive sized reads continue where the previous one stopped, then return b""."""
    payload = b"Hello World!"
    stream = TensorIO(Tensor(payload))
    # (requested byte count, expected bytes)
    reads = (
      (1, payload[:1]),
      (5, payload[1:6]),
      (100, payload[6:]),
      (100, b""),
    )
    for count, expected in reads:
      self.assertEqual(stream.read(count), expected)

  def test_read_nolen(self):
    """read() without a size returns everything from the current position onwards."""
    payload = b"Hello World!"
    stream = TensorIO(Tensor(payload))
    stream.seek(2)
    remainder = stream.read()
    self.assertEqual(remainder, payload[2:])
# run this module's test cases when the file is executed directly as a script
if __name__ == '__main__':
  unittest.main()