Write tar_extract (#6180)

* Add tar_extract

* Add tar_extract tests

* Fix dtype for initialization from path

* Tests for path initialization

* rm print

---------

Co-authored-by: Maximilian Weichart <maximilian.weichart@icloud.com>
This commit is contained in:
Max-We
2024-08-19 21:06:17 +02:00
committed by GitHub
parent 8556d0c642
commit 53b20afa3f
3 changed files with 131 additions and 3 deletions

View File

@@ -1,8 +1,11 @@
import os
import pathlib, tempfile, unittest
import tarfile
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
from tinygrad.helpers import Timing, fetch, temp, CI
from test.helpers import is_dtype_supported
@@ -333,5 +336,114 @@ class TestDiskTensor(unittest.TestCase):
on_dev = t.to(Device.DEFAULT).realize()
np.testing.assert_equal(on_dev.numpy(), t.numpy())
class TestTarExtract(unittest.TestCase):
  """Exercises `tar_extract` against a small uncompressed archive that is rebuilt for every test."""

  def setUp(self):
    # TemporaryDirectory deletes the whole tree, so cleanup does not depend on knowing every file that
    # was created. Registering it via addCleanup right away also removes the directory when the rest of
    # setUp raises (tearDown is not called in that case).
    self._tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(self._tmp_dir.cleanup)
    self.test_dir = self._tmp_dir.name

    # archive member name -> expected raw contents (includes a zero-length file on purpose)
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b''
    }

    # Write each file to disk, then add it to the archive under its bare name.
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        tar.add(file_path, arcname=filename)

    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract('non_existent_file.tar')

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)
class TestPathTensor(unittest.TestCase):
  """Covers building a Tensor straight from a `pathlib.Path` pointing at a 100-byte file."""

  def setUp(self):
    self.temp_dir = tempfile.TemporaryDirectory()
    # bytes 0..99, so every position in the file holds a distinct value
    self.test_data = np.arange(100, dtype=np.uint8).tobytes()
    self.test_file = pathlib.Path(self.temp_dir.name) / "test_file.bin"
    self.test_file.write_bytes(self.test_data)

  def tearDown(self):
    self.temp_dir.cleanup()

  def _file_bytes_as(self, np_dtype):
    # the fixture bytes reinterpreted as `np_dtype`, used as the reference for value comparisons
    return np.frombuffer(self.test_data, dtype=np_dtype)

  def test_path_tensor_no_device(self):
    tensor = Tensor(self.test_file)
    self.assertEqual(tensor.shape, (100,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertTrue(tensor.device.startswith("DISK:"))
    np.testing.assert_array_equal(tensor.numpy(), self._file_bytes_as(np.uint8))

  def test_path_tensor_with_device(self):
    tensor = Tensor(self.test_file, device="CPU")
    self.assertEqual(tensor.shape, (100,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertEqual(tensor.device, "CPU")
    np.testing.assert_array_equal(tensor.numpy(), self._file_bytes_as(np.uint8))

  def test_path_tensor_empty_file(self):
    zero_length = pathlib.Path(self.temp_dir.name) / "empty_file.bin"
    zero_length.touch()
    tensor = Tensor(zero_length)
    self.assertEqual(tensor.shape, (0,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertTrue(tensor.device.startswith("DISK:"))

  def test_path_tensor_non_existent_file(self):
    missing = pathlib.Path(self.temp_dir.name) / "non_existent.bin"
    with self.assertRaises(FileNotFoundError):
      Tensor(missing)

  def test_path_tensor_with_dtype(self):
    # 100 bytes viewed as 2-byte elements -> 50 entries
    tensor = Tensor(self.test_file, dtype=dtypes.int16)
    self.assertEqual(tensor.shape, (50,))
    self.assertEqual(tensor.dtype, dtypes.int16)
    self.assertTrue(tensor.device.startswith("DISK:"))
    np.testing.assert_array_equal(tensor.numpy(), self._file_bytes_as(np.int16))

  def test_path_tensor_copy_to_device(self):
    on_cpu = Tensor(self.test_file).to("CPU")
    self.assertEqual(on_cpu.device, "CPU")
    np.testing.assert_array_equal(on_cpu.numpy(), self._file_bytes_as(np.uint8))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()

View File

@@ -129,6 +129,18 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
else: v.replace(state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]
def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
  """
  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  Every value is a slice of a single tensor created from the archive path, covering that member's data bytes.
  Only regular-file members (`tarfile.REGTYPE`) are returned; all other member types are skipped.

  The archive must be an uncompressed tar file. Member offsets are applied to the bytes of `fn` as stored on
  disk, so a compressed archive cannot be sliced this way and raises `tarfile.ReadError`, as does any file
  that is not a valid tar archive.

  ```python
  tensors = nn.state.tar_extract("archive.tar")
  ```
  """
  t = Tensor(pathlib.Path(fn))
  # "r:" opens for reading without compression. Plain "r" would transparently decompress gz/bz2/xz archives,
  # and the resulting offsets (positions in the decompressed stream) would then index the wrong bytes of `t`.
  with tarfile.open(fn, "r:") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
# torch support!
def torch_load(fn:str) -> Dict[str, Tensor]:

View File

@@ -1,7 +1,7 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import dataclasses
import time, math, itertools, functools, struct, sys, inspect
import time, math, itertools, functools, struct, sys, inspect, pathlib
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
@@ -105,10 +105,11 @@ class Tensor:
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
# tensors can have gradients if you have called .backward
@@ -136,6 +137,9 @@ class Tensor:
elif isinstance(data, np.ndarray):
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
elif isinstance(data, pathlib.Path):
dtype = dtype or dtypes.uint8
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
# by this point, it has to be a LazyBuffer
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):