diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 1f4b22fdc2..c40ec73252 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -1,8 +1,11 @@
+import os
 import pathlib, tempfile, unittest
+import tarfile
+
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.dtype import DType
-from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
+from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
 from tinygrad.helpers import Timing, fetch, temp, CI
 from test.helpers import is_dtype_supported
 
@@ -333,5 +336,124 @@ class TestDiskTensor(unittest.TestCase):
     on_dev = t.to(Device.DEFAULT).realize()
     np.testing.assert_equal(on_dev.numpy(), t.numpy())
 
+class TestTarExtract(unittest.TestCase):
+  """tar_extract maps each regular member of a tar archive to a tensor holding that member's bytes."""
+  def setUp(self):
+    # TemporaryDirectory + addCleanup removes everything even if setUp fails midway or a test leaves extra files behind
+    self.temp_dir = tempfile.TemporaryDirectory()
+    self.addCleanup(self.temp_dir.cleanup)
+    self.test_dir = self.temp_dir.name
+    self.test_files = {
+      'file1.txt': b'Hello, World!',
+      'file2.bin': b'\x00\x01\x02\x03\x04',
+      'empty_file.txt': b''
+    }
+    self.tar_path = os.path.join(self.test_dir, 'test.tar')
+    with tarfile.open(self.tar_path, 'w') as tar:
+      for filename, content in self.test_files.items():
+        file_path = os.path.join(self.test_dir, filename)
+        with open(file_path, 'wb') as f:
+          f.write(content)
+        tar.add(file_path, arcname=filename)
+
+    # not a tar archive at all: tar_extract is expected to surface tarfile.ReadError for it
+    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
+    with open(self.invalid_tar_path, 'wb') as f:
+      f.write(b'This is not a valid tar file')
+
+  def test_tar_extract_returns_dict(self):
+    result = tar_extract(self.tar_path)
+    self.assertIsInstance(result, dict)
+
+  def test_tar_extract_correct_keys(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(set(result.keys()), set(self.test_files.keys()))
+
+  def test_tar_extract_content_size(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      self.assertEqual(len(result[filename]), len(content))
+
+  def test_tar_extract_content_values(self):
+    result = tar_extract(self.tar_path)
+    for filename, content in self.test_files.items():
+      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
+
+  def test_tar_extract_empty_file(self):
+    result = tar_extract(self.tar_path)
+    self.assertEqual(len(result['empty_file.txt']), 0)
+
+  def test_tar_extract_v7_regular_file(self):
+    # old V7 archives flag regular files as AREGTYPE (NUL typeflag) instead of REGTYPE
+    legacy_path = os.path.join(self.test_dir, 'legacy.tar')
+    src_path = os.path.join(self.test_dir, 'file1.txt')
+    with tarfile.open(legacy_path, 'w') as tar, open(src_path, 'rb') as f:
+      info = tar.gettarinfo(src_path, arcname='file1.txt')
+      info.type = tarfile.AREGTYPE
+      tar.addfile(info, f)
+    result = tar_extract(legacy_path)
+    np.testing.assert_array_equal(result['file1.txt'].numpy(), np.frombuffer(self.test_files['file1.txt'], dtype=np.uint8))
+
+  def test_tar_extract_non_existent_file(self):
+    # look inside the private temp dir so the result cannot depend on what happens to be in the cwd
+    with self.assertRaises(FileNotFoundError):
+      tar_extract(os.path.join(self.test_dir, 'non_existent_file.tar'))
+
+  def test_tar_extract_invalid_file(self):
+    with self.assertRaises(tarfile.ReadError):
+      tar_extract(self.invalid_tar_path)
+
+class TestPathTensor(unittest.TestCase):
+  """Tensor(pathlib.Path) exposes a file's bytes as a tensor, living on DISK unless a device is requested."""
+  def setUp(self):
+    self.temp_dir = tempfile.TemporaryDirectory()
+    self.test_file = pathlib.Path(self.temp_dir.name) / "test_file.bin"
+    self.test_data = np.arange(100, dtype=np.uint8).tobytes()
+    with open(self.test_file, "wb") as f:
+      f.write(self.test_data)
+
+  def tearDown(self):
+    self.temp_dir.cleanup()
+
+  def test_path_tensor_no_device(self):
+    t = Tensor(self.test_file)
+    self.assertEqual(t.shape, (100,))
+    self.assertEqual(t.dtype, dtypes.uint8)
+    self.assertTrue(t.device.startswith("DISK:"))
+    np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
+
+  def test_path_tensor_with_device(self):
+    t = Tensor(self.test_file, device="CPU")
+    self.assertEqual(t.shape, (100,))
+    self.assertEqual(t.dtype, dtypes.uint8)
+    self.assertEqual(t.device, "CPU")
+    np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
+
+  def test_path_tensor_empty_file(self):
+    empty_file = pathlib.Path(self.temp_dir.name) / "empty_file.bin"
+    empty_file.touch()
+    t = Tensor(empty_file)
+    self.assertEqual(t.shape, (0,))
+    self.assertEqual(t.dtype, dtypes.uint8)
+    self.assertTrue(t.device.startswith("DISK:"))
+
+  def test_path_tensor_non_existent_file(self):
+    non_existent_file = pathlib.Path(self.temp_dir.name) / "non_existent.bin"
+    with self.assertRaises(FileNotFoundError):
+      Tensor(non_existent_file)
+
+  def test_path_tensor_with_dtype(self):
+    t = Tensor(self.test_file, dtype=dtypes.int16)
+    self.assertEqual(t.shape, (50,))
+    self.assertEqual(t.dtype, dtypes.int16)
+    self.assertTrue(t.device.startswith("DISK:"))
+    np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.int16))
+
+  def test_path_tensor_copy_to_device(self):
+    t = Tensor(self.test_file)
+    t_cpu = t.to("CPU")
+    self.assertEqual(t_cpu.device, "CPU")
+    np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 0ddd673caa..4aef4988d4 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -129,6 +129,24 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]
 
+def tar_extract(fn:os.PathLike) -> Dict[str, Tensor]:
+  """
+  Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
+
+  Only regular files are returned; directories, links and other special members are skipped.
+  Every returned tensor is a slice of one tensor built over the archive file.
+  Raises `FileNotFoundError` if `fn` does not exist and `tarfile.ReadError` if it is not a readable tar archive.
+
+  ```python
+  tensors = nn.state.tar_extract("archive.tar")
+  ```
+  """
+  t = Tensor(pathlib.Path(fn))
+  with tarfile.open(fn, "r") as tar:
+    # old V7 archives mark regular files with AREGTYPE; their data sits at offset_data just like REGTYPE members
+    regular = (tarfile.REGTYPE, tarfile.AREGTYPE)
+    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type in regular}
+
 # torch support!
 
 def torch_load(fn:str) -> Dict[str, Tensor]:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 54b7d00aa5..fde1842bba 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,7 +1,7 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
 import dataclasses
-import time, math, itertools, functools, struct, sys, inspect
+import time, math, itertools, functools, struct, sys, inspect, pathlib
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
@@ -105,10 +105,11 @@ class Tensor:
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
 
-  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
+  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
+    if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}"  # keep it on the disk if device is None
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
 
     # tensors can have gradients if you have called .backward
@@ -136,6 +137,10 @@ class Tensor:
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
       else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+    elif isinstance(data, pathlib.Path):
+      dtype = dtype or dtypes.uint8
+      # element count is the file size floor-divided by the dtype size, so bytes that do not fill a whole element are left out
+      data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
 
     # by this point, it has to be a LazyBuffer
     if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):