mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Write tar_extract (#6180)
* Add tar_extract
* Add tar_extract tests
* Fix dtype for initialization from path
* Tests for path initialization
* rm print

---------

Co-authored-by: Maximilian Weichart <maximilian.weichart@icloud.com>
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import pathlib, tempfile, unittest
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
|
||||
from tinygrad.helpers import Timing, fetch, temp, CI
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
@@ -333,5 +336,114 @@ class TestDiskTensor(unittest.TestCase):
|
||||
on_dev = t.to(Device.DEFAULT).realize()
|
||||
np.testing.assert_equal(on_dev.numpy(), t.numpy())
|
||||
|
||||
class TestTarExtract(unittest.TestCase):
  """Exercises `tar_extract` against a small tar archive that is rebuilt for every test."""

  def setUp(self):
    # Register the cleanup immediately: cleanups run even when a later setUp step raises (tearDown would not),
    # and the whole directory tree is removed regardless of which files ended up inside it.
    self._temp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(self._temp_dir.cleanup)
    self.test_dir = self._temp_dir.name

    # archive member name -> raw bytes expected back from tar_extract
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b''
    }
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        # arcname stores the member under its bare file name instead of the temp directory path
        tar.add(file_path, arcname=filename)

    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    # Anchor the path inside the fresh temp directory so the outcome does not depend on what the current
    # working directory happens to contain.
    missing_path = os.path.join(self.test_dir, 'non_existent_file.tar')
    with self.assertRaises(FileNotFoundError):
      tar_extract(missing_path)

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)
|
||||
|
||||
class TestPathTensor(unittest.TestCase):
  """Covers constructing a `Tensor` directly from a `pathlib.Path`."""

  def setUp(self):
    # 100 bytes with values 0..99, written to a file inside a per-test temporary directory
    self.test_data = np.arange(100, dtype=np.uint8).tobytes()
    self.temp_dir = tempfile.TemporaryDirectory()
    self.test_file = pathlib.Path(self.temp_dir.name) / "test_file.bin"
    with open(self.test_file, "wb") as fh:
      fh.write(self.test_data)

  def tearDown(self):
    self.temp_dir.cleanup()

  def test_path_tensor_no_device(self):
    # without an explicit device the tensor is expected to report a DISK: device
    expected = np.frombuffer(self.test_data, dtype=np.uint8)
    tensor = Tensor(self.test_file)
    self.assertEqual(tensor.shape, (100,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertTrue(tensor.device.startswith("DISK:"))
    np.testing.assert_array_equal(tensor.numpy(), expected)

  def test_path_tensor_with_device(self):
    expected = np.frombuffer(self.test_data, dtype=np.uint8)
    tensor = Tensor(self.test_file, device="CPU")
    self.assertEqual(tensor.shape, (100,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertEqual(tensor.device, "CPU")
    np.testing.assert_array_equal(tensor.numpy(), expected)

  def test_path_tensor_empty_file(self):
    # a zero-byte file should give a zero-length uint8 tensor
    empty_file = pathlib.Path(self.temp_dir.name) / "empty_file.bin"
    empty_file.touch()
    tensor = Tensor(empty_file)
    self.assertEqual(tensor.shape, (0,))
    self.assertEqual(tensor.dtype, dtypes.uint8)
    self.assertTrue(tensor.device.startswith("DISK:"))

  def test_path_tensor_non_existent_file(self):
    missing = pathlib.Path(self.temp_dir.name) / "non_existent.bin"
    self.assertRaises(FileNotFoundError, Tensor, missing)

  def test_path_tensor_with_dtype(self):
    # 100 bytes read as 2-byte int16 elements -> 50 elements
    expected = np.frombuffer(self.test_data, dtype=np.int16)
    tensor = Tensor(self.test_file, dtype=dtypes.int16)
    self.assertEqual(tensor.shape, (50,))
    self.assertEqual(tensor.dtype, dtypes.int16)
    self.assertTrue(tensor.device.startswith("DISK:"))
    np.testing.assert_array_equal(tensor.numpy(), expected)

  def test_path_tensor_copy_to_device(self):
    expected = np.frombuffer(self.test_data, dtype=np.uint8)
    on_cpu = Tensor(self.test_file).to("CPU")
    self.assertEqual(on_cpu.device, "CPU")
    np.testing.assert_array_equal(on_cpu.numpy(), expected)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -129,6 +129,18 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
|
||||
else: v.replace(state_dict[k].to(v.device)).realize()
|
||||
if consume: del state_dict[k]
|
||||
|
||||
def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
  """
  Maps each regular file stored in the tar archive `fn` to a tensor holding that file's bytes.

  The archive is wrapped in one `Tensor` built from its path, and every value in the result is the slice of that
  tensor covering the member's data region (`offset_data` up to `offset_data + size`).
  Members whose type is not `tarfile.REGTYPE` are left out.

  ```python
  tensors = nn.state.tar_extract("archive.tar")
  ```
  """
  archive = Tensor(pathlib.Path(fn))
  extracted: Dict[str, Tensor] = {}
  with tarfile.open(fn, "r") as tar:
    for info in tar:
      # only plain regular-file members are returned
      if info.type != tarfile.REGTYPE: continue
      start = info.offset_data
      extracted[info.name] = archive[start:start+info.size]
  return extracted
|
||||
|
||||
# torch support!
|
||||
|
||||
def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import time, math, itertools, functools, struct, sys, inspect
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
|
||||
from collections import defaultdict
|
||||
@@ -105,10 +105,11 @@ class Tensor:
|
||||
training: ClassVar[bool] = False
|
||||
no_grad: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
|
||||
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
|
||||
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
||||
if dtype is not None: dtype = to_dtype(dtype)
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
||||
|
||||
# tensors can have gradients if you have called .backward
|
||||
@@ -136,6 +137,9 @@ class Tensor:
|
||||
elif isinstance(data, np.ndarray):
|
||||
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
|
||||
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
|
||||
elif isinstance(data, pathlib.Path):
|
||||
dtype = dtype or dtypes.uint8
|
||||
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
|
||||
|
||||
# by this point, it has to be a LazyBuffer
|
||||
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
|
||||
|
||||
Reference in New Issue
Block a user