mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Write tar_extract (#6180)
* Add tar_extract * Add tar_extract tests * Fix dtype for initialization from path * Tests for path initialization * rm print --------- Co-authored-by: Maximilian Weichart <maximilian.weichart@icloud.com>
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import pathlib, tempfile, unittest
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
|
||||
from tinygrad.helpers import Timing, fetch, temp, CI
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
@@ -333,5 +336,114 @@ class TestDiskTensor(unittest.TestCase):
|
||||
on_dev = t.to(Device.DEFAULT).realize()
|
||||
np.testing.assert_equal(on_dev.numpy(), t.numpy())
|
||||
|
||||
class TestTarExtract(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.test_files = {
|
||||
'file1.txt': b'Hello, World!',
|
||||
'file2.bin': b'\x00\x01\x02\x03\x04',
|
||||
'empty_file.txt': b''
|
||||
}
|
||||
self.tar_path = os.path.join(self.test_dir, 'test.tar')
|
||||
with tarfile.open(self.tar_path, 'w') as tar:
|
||||
for filename, content in self.test_files.items():
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
tar.add(file_path, arcname=filename)
|
||||
|
||||
# Create invalid tar file
|
||||
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
|
||||
with open(self.invalid_tar_path, 'wb') as f:
|
||||
f.write(b'This is not a valid tar file')
|
||||
|
||||
def tearDown(self):
|
||||
for filename in self.test_files:
|
||||
os.remove(os.path.join(self.test_dir, filename))
|
||||
os.remove(self.tar_path)
|
||||
os.remove(self.invalid_tar_path)
|
||||
os.rmdir(self.test_dir)
|
||||
|
||||
def test_tar_extract_returns_dict(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertIsInstance(result, dict)
|
||||
|
||||
def test_tar_extract_correct_keys(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
|
||||
|
||||
def test_tar_extract_content_size(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
for filename, content in self.test_files.items():
|
||||
self.assertEqual(len(result[filename]), len(content))
|
||||
|
||||
def test_tar_extract_content_values(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
for filename, content in self.test_files.items():
|
||||
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
|
||||
|
||||
def test_tar_extract_empty_file(self):
|
||||
result = tar_extract(self.tar_path)
|
||||
self.assertEqual(len(result['empty_file.txt']), 0)
|
||||
|
||||
def test_tar_extract_non_existent_file(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
tar_extract('non_existent_file.tar')
|
||||
|
||||
def test_tar_extract_invalid_file(self):
|
||||
with self.assertRaises(tarfile.ReadError):
|
||||
tar_extract(self.invalid_tar_path)
|
||||
|
||||
class TestPathTensor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.test_file = pathlib.Path(self.temp_dir.name) / "test_file.bin"
|
||||
self.test_data = np.arange(100, dtype=np.uint8).tobytes()
|
||||
with open(self.test_file, "wb") as f:
|
||||
f.write(self.test_data)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_path_tensor_no_device(self):
|
||||
t = Tensor(self.test_file)
|
||||
self.assertEqual(t.shape, (100,))
|
||||
self.assertEqual(t.dtype, dtypes.uint8)
|
||||
self.assertTrue(t.device.startswith("DISK:"))
|
||||
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_with_device(self):
|
||||
t = Tensor(self.test_file, device="CPU")
|
||||
self.assertEqual(t.shape, (100,))
|
||||
self.assertEqual(t.dtype, dtypes.uint8)
|
||||
self.assertEqual(t.device, "CPU")
|
||||
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_empty_file(self):
|
||||
empty_file = pathlib.Path(self.temp_dir.name) / "empty_file.bin"
|
||||
empty_file.touch()
|
||||
t = Tensor(empty_file)
|
||||
self.assertEqual(t.shape, (0,))
|
||||
self.assertEqual(t.dtype, dtypes.uint8)
|
||||
self.assertTrue(t.device.startswith("DISK:"))
|
||||
|
||||
def test_path_tensor_non_existent_file(self):
|
||||
non_existent_file = pathlib.Path(self.temp_dir.name) / "non_existent.bin"
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
Tensor(non_existent_file)
|
||||
|
||||
def test_path_tensor_with_dtype(self):
|
||||
t = Tensor(self.test_file, dtype=dtypes.int16)
|
||||
self.assertEqual(t.shape, (50,))
|
||||
self.assertEqual(t.dtype, dtypes.int16)
|
||||
self.assertTrue(t.device.startswith("DISK:"))
|
||||
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.int16))
|
||||
|
||||
def test_path_tensor_copy_to_device(self):
|
||||
t = Tensor(self.test_file)
|
||||
t_cpu = t.to("CPU")
|
||||
self.assertEqual(t_cpu.device, "CPU")
|
||||
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user