mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
tar_extract with Tensors (#7853)
* initial * USTAR, PAX and GNU support + testing * from_bytes byteorder * use TarInfo.frombuf * tensor only usage * remove contextlib.suppress * shorter ow,pax * more tests * testing length + move tests * cleanup * new approach: RawTensorIO * fix fetch * enable read test * cleanup and ignore fix * fix for python < 3.12 * make it RawIO * functions --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import os, pathlib, tempfile, unittest, tarfile
|
||||
import pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, fetch, temp, CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -333,63 +333,6 @@ class TestDiskTensor(unittest.TestCase):
|
||||
on_dev = t.to(Device.DEFAULT).realize()
|
||||
np.testing.assert_equal(on_dev.numpy(), t.numpy())
|
||||
|
||||
class TestTarExtract(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.test_files = {
|
||||
'file1.txt': b'Hello, World!',
|
||||
'file2.bin': b'\x00\x01\x02\x03\x04',
|
||||
'empty_file.txt': b''
|
||||
}
|
||||
self.tar_path = os.path.join(self.test_dir, 'test.tar')
|
||||
with tarfile.open(self.tar_path, 'w') as tar:
|
||||
for filename, content in self.test_files.items():
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
tar.add(file_path, arcname=filename)
|
||||
|
||||
# Create invalid tar file
|
||||
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
|
||||
with open(self.invalid_tar_path, 'wb') as f:
|
||||
f.write(b'This is not a valid tar file')
|
||||
|
||||
def tearDown(self):
|
||||
for filename in self.test_files:
|
||||
os.remove(os.path.join(self.test_dir, filename))
|
||||
os.remove(self.tar_path)
|
||||
os.remove(self.invalid_tar_path)
|
||||
os.rmdir(self.test_dir)
|
||||
|
||||
def test_tar_extract_returns_dict(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertIsInstance(result, dict)
|
||||
|
||||
def test_tar_extract_correct_keys(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
|
||||
|
||||
def test_tar_extract_content_size(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
for filename, content in self.test_files.items():
|
||||
self.assertEqual(len(result[filename]), len(content))
|
||||
|
||||
def test_tar_extract_content_values(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
for filename, content in self.test_files.items():
|
||||
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
|
||||
|
||||
def test_tar_extract_empty_file(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertEqual(len(result['empty_file.txt']), 0)
|
||||
|
||||
def test_tar_extract_non_existent_file(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
tar_extract('non_existent_file.tar')
|
||||
|
||||
def test_tar_extract_invalid_file(self):
|
||||
with self.assertRaises(tarfile.ReadError):
|
||||
tar_extract(self.invalid_tar_path)
|
||||
|
||||
class TestPathTensor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
83
test/unit/test_tar.py
Normal file
83
test/unit/test_tar.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import unittest, tarfile, io, os, pathlib
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.nn.state import tar_extract
|
||||
|
||||
class TestTarExtractPAX(unittest.TestCase):
|
||||
tar_format = tarfile.PAX_FORMAT
|
||||
max_link_len = 1000_000
|
||||
test_files = {
|
||||
'a/file1.txt': b'Hello, World!',
|
||||
'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
|
||||
'empty_file.txt': b'',
|
||||
'512file': b'a' * 512,
|
||||
'long_file': b'some data' * 100,
|
||||
'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
|
||||
'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
|
||||
}
|
||||
|
||||
def create_tar_tensor(self):
|
||||
fobj = io.BytesIO()
|
||||
test_dirs = set(os.path.dirname(k) for k in self.test_files.keys()).difference({ '' })
|
||||
with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
|
||||
for dirname in test_dirs:
|
||||
dir_info = tarfile.TarInfo(name=dirname)
|
||||
dir_info.type = tarfile.DIRTYPE
|
||||
tar.addfile(dir_info)
|
||||
|
||||
for filename, content in self.test_files.items():
|
||||
file_info = tarfile.TarInfo(name=filename)
|
||||
file_info.size = len(content)
|
||||
tar.addfile(file_info, io.BytesIO(content))
|
||||
|
||||
if len(filename) < self.max_link_len:
|
||||
link_info = tarfile.TarInfo(name=filename + '.lnk')
|
||||
link_info.type = tarfile.SYMTYPE
|
||||
link_info.linkname = filename
|
||||
tar.addfile(link_info)
|
||||
return Tensor(fobj.getvalue())
|
||||
|
||||
def test_tar_extract_returns_dict(self):
|
||||
result = tar_extract(self.create_tar_tensor())
|
||||
self.assertIsInstance(result, dict)
|
||||
|
||||
def test_tar_extract_correct_keys(self):
|
||||
result = tar_extract(self.create_tar_tensor())
|
||||
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
|
||||
|
||||
def test_tar_extract_content_size(self):
|
||||
result = tar_extract(self.create_tar_tensor())
|
||||
for filename, content in self.test_files.items():
|
||||
self.assertEqual(len(result[filename]), len(content))
|
||||
|
||||
def test_tar_extract_content_values(self):
|
||||
result = tar_extract(self.create_tar_tensor())
|
||||
for filename, content in self.test_files.items():
|
||||
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
|
||||
|
||||
def test_tar_extract_empty_file(self):
|
||||
result = tar_extract(self.create_tar_tensor())
|
||||
self.assertEqual(len(result['empty_file.txt']), 0)
|
||||
|
||||
def test_tar_extract_non_existent_file(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))
|
||||
|
||||
def test_tar_extract_invalid_file(self):
|
||||
with self.assertRaises(tarfile.ReadError):
|
||||
tar_extract(Tensor(b'This is not a valid tar file'))
|
||||
|
||||
def test_tar_extract_invalid_file_long(self):
|
||||
with self.assertRaises(tarfile.ReadError):
|
||||
tar_extract(Tensor(b'This is not a valid tar file'*100))
|
||||
|
||||
class TestTarExtractUSTAR(TestTarExtractPAX):
|
||||
tar_format = tarfile.USTAR_FORMAT
|
||||
max_link_len = 100
|
||||
test_files = {k: v for k, v in TestTarExtractPAX.test_files.items() if len(k) < 256}
|
||||
|
||||
class TestTarExtractGNU(TestTarExtractPAX):
|
||||
tar_format = tarfile.GNU_FORMAT
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
38
test/unit/test_tensor_io.py
Normal file
38
test/unit/test_tensor_io.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.nn.state import TensorIO
|
||||
|
||||
class TestTensorIO(unittest.TestCase):
|
||||
def test_create(self):
|
||||
with self.assertRaises(ValueError):
|
||||
TensorIO(Tensor(b"Hello World").reshape(1, -1))
|
||||
with self.assertRaises(ValueError):
|
||||
TensorIO(Tensor([], dtype=dtypes.int64).reshape(1, -1))
|
||||
|
||||
def test_seek(self):
|
||||
t = Tensor(b"Hello World!")
|
||||
fobj = TensorIO(t)
|
||||
self.assertEqual(fobj.tell(), 0)
|
||||
self.assertEqual(fobj.seek(1), 1)
|
||||
self.assertEqual(fobj.seek(-2, 2), len(t) - 2)
|
||||
self.assertEqual(fobj.seek(1, 1), len(t) - 1)
|
||||
self.assertEqual(fobj.seek(10, 1), len(t))
|
||||
self.assertEqual(fobj.seek(10, 2), len(t))
|
||||
self.assertEqual(fobj.seek(-10, 0), 0)
|
||||
|
||||
def test_read(self):
|
||||
data = b"Hello World!"
|
||||
fobj = TensorIO(Tensor(data))
|
||||
self.assertEqual(fobj.read(1), data[:1])
|
||||
self.assertEqual(fobj.read(5), data[1:6])
|
||||
self.assertEqual(fobj.read(100), data[6:])
|
||||
self.assertEqual(fobj.read(100), b"")
|
||||
|
||||
def test_read_nolen(self):
|
||||
data = b"Hello World!"
|
||||
fobj = TensorIO(Tensor(data))
|
||||
fobj.seek(2)
|
||||
self.assertEqual(fobj.read(), data[2:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user