Write tar_extract (#6180)

* Add tar_extract

* Add tar_extract tests

* Fix dtype for initialization from path

* Tests for path initialization

* rm print

---------

Co-authored-by: Maximilian Weichart <maximilian.weichart@icloud.com>
This commit is contained in:
Max-We
2024-08-19 21:06:17 +02:00
committed by GitHub
parent 8556d0c642
commit 53b20afa3f
3 changed files with 131 additions and 3 deletions

View File

@@ -1,8 +1,11 @@
import os
import pathlib, tempfile, unittest
import tarfile
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
from tinygrad.helpers import Timing, fetch, temp, CI
from test.helpers import is_dtype_supported
@@ -333,5 +336,114 @@ class TestDiskTensor(unittest.TestCase):
on_dev = t.to(Device.DEFAULT).realize()
np.testing.assert_equal(on_dev.numpy(), t.numpy())
class TestTarExtract(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.test_files = {
'file1.txt': b'Hello, World!',
'file2.bin': b'\x00\x01\x02\x03\x04',
'empty_file.txt': b''
}
self.tar_path = os.path.join(self.test_dir, 'test.tar')
with tarfile.open(self.tar_path, 'w') as tar:
for filename, content in self.test_files.items():
file_path = os.path.join(self.test_dir, filename)
with open(file_path, 'wb') as f:
f.write(content)
tar.add(file_path, arcname=filename)
# Create invalid tar file
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
with open(self.invalid_tar_path, 'wb') as f:
f.write(b'This is not a valid tar file')
def tearDown(self):
for filename in self.test_files:
os.remove(os.path.join(self.test_dir, filename))
os.remove(self.tar_path)
os.remove(self.invalid_tar_path)
os.rmdir(self.test_dir)
def test_tar_extract_returns_dict(self):
result = tar_extract(self.tar_path)
self.assertIsInstance(result, dict)
def test_tar_extract_correct_keys(self):
result = tar_extract(self.tar_path)
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
def test_tar_extract_content_size(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
self.assertEqual(len(result[filename]), len(content))
def test_tar_extract_content_values(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
def test_tar_extract_empty_file(self):
result = tar_extract(self.tar_path)
self.assertEqual(len(result['empty_file.txt']), 0)
def test_tar_extract_non_existent_file(self):
with self.assertRaises(FileNotFoundError):
tar_extract('non_existent_file.tar')
def test_tar_extract_invalid_file(self):
with self.assertRaises(tarfile.ReadError):
tar_extract(self.invalid_tar_path)
class TestPathTensor(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.test_file = pathlib.Path(self.temp_dir.name) / "test_file.bin"
self.test_data = np.arange(100, dtype=np.uint8).tobytes()
with open(self.test_file, "wb") as f:
f.write(self.test_data)
def tearDown(self):
self.temp_dir.cleanup()
def test_path_tensor_no_device(self):
t = Tensor(self.test_file)
self.assertEqual(t.shape, (100,))
self.assertEqual(t.dtype, dtypes.uint8)
self.assertTrue(t.device.startswith("DISK:"))
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
def test_path_tensor_with_device(self):
t = Tensor(self.test_file, device="CPU")
self.assertEqual(t.shape, (100,))
self.assertEqual(t.dtype, dtypes.uint8)
self.assertEqual(t.device, "CPU")
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
def test_path_tensor_empty_file(self):
empty_file = pathlib.Path(self.temp_dir.name) / "empty_file.bin"
empty_file.touch()
t = Tensor(empty_file)
self.assertEqual(t.shape, (0,))
self.assertEqual(t.dtype, dtypes.uint8)
self.assertTrue(t.device.startswith("DISK:"))
def test_path_tensor_non_existent_file(self):
non_existent_file = pathlib.Path(self.temp_dir.name) / "non_existent.bin"
with self.assertRaises(FileNotFoundError):
Tensor(non_existent_file)
def test_path_tensor_with_dtype(self):
t = Tensor(self.test_file, dtype=dtypes.int16)
self.assertEqual(t.shape, (50,))
self.assertEqual(t.dtype, dtypes.int16)
self.assertTrue(t.device.startswith("DISK:"))
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.int16))
def test_path_tensor_copy_to_device(self):
t = Tensor(self.test_file)
t_cpu = t.to("CPU")
self.assertEqual(t_cpu.device, "CPU")
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
if __name__ == "__main__":
unittest.main()