feat: Tensor.{load, store} (#12629)

This commit is contained in:
wozeparrot
2025-10-13 08:04:41 -07:00
committed by GitHub
parent 0f776c6e46
commit 47e0c43976
3 changed files with 78 additions and 3 deletions

22
test/unit/test_tinyfs.py Normal file
View File

@@ -0,0 +1,22 @@
import unittest
from tinygrad import Tensor
class TestLoadStore(unittest.TestCase):
  """Shape checks for Tensor.load / Tensor.store against the tinyfs backend."""

  def _assert_shape(self, t, expected):
    # kernelize() only schedules the op; we verify the resulting shape, not the data
    self.assertEqual(t.kernelize().shape, expected)

  def test_load_shape(self):
    # a 16-byte hash tensor loads back as the requested number of bytes
    self._assert_shape(Tensor(bytes(16)).load(1024), (1024,))

  def test_store_shape(self):
    # storing any tensor yields its 16-byte hash
    self._assert_shape(Tensor.zeros(1024).store(), (16,))

  def test_load_large_shape(self):
    # multi-chunk (> 1 MiB) load keeps the requested size
    self._assert_shape(Tensor(bytes(16)).load(10_000_000), (10_000_000,))

  def test_store_large_shape(self):
    # multi-chunk store still reduces to a single 16-byte root hash
    self._assert_shape(Tensor.zeros(10_000_000).store(), (16,))
# allow running this file directly: `python test/unit/test_tinyfs.py`
if __name__ == "__main__":
  unittest.main()

View File

@@ -2,9 +2,9 @@ import socket, uuid, json, asyncio, threading
from contextlib import asynccontextmanager
from tinygrad.device import Compiled, Allocator
from tinygrad.helpers import DEBUG, getenv
from tinygrad import Tensor
TINYFS_ENDPOINT = getenv("TINYFS_ENDPOINT", "localhost:6767")
CHUNK_SIZE = 2**20
class TinyFSDevice(Compiled):
def __init__(self, device:str):
@@ -116,8 +116,8 @@ class TinyFSAllocator(Allocator[TinyFSDevice]):
async def _worker(item):
i, loc, h = item
async with self.dev.connection(loc) as (reader, writer):
ptr = i * CHUNK_SIZE
size = min(len(dest[ptr:ptr+CHUNK_SIZE]), CHUNK_SIZE)
ptr = i * Tensor.CHUNK_SIZE
size = min(len(dest[ptr:ptr+Tensor.CHUNK_SIZE]), Tensor.CHUNK_SIZE)
writer.write(f"CHUNK_OUT {size}\r\n".encode())
writer.write(h)

View File

@@ -411,6 +411,59 @@ class Tensor(MathTrait):
"""
return self.replace(self.shard(devices, axis))
CHUNK_SIZE = 2**20
  def load(self, size:int) -> Tensor:
    """
    Load a tensor from storage.

    `self` should be a flat 16-byte uint8 tensor holding the root hash of the
    data to load; `size` is the number of bytes to read back.
    Returns a 1-D uint8 tensor of shape (size,).
    """
    # TODO: this should work locally as well
    assert self.dtype == dtypes.uint8, "hash is expected to be uint8"
    h = self.contiguous().flatten()
    assert h.shape[0] == 16, "expected hash"
    # data is split into CHUNK_SIZE-byte chunks; hashes are 16 bytes, so one chunk
    # references CHUNK_SIZE // 16 children -- presumably a merkle-style hash tree
    # of that fanout (TODO: confirm against the tinyfs server protocol)
    base_chunks = math.ceil(size / Tensor.CHUNK_SIZE)
    tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
    # walk from the root hash down to the data level, one tinyfs round-trip per level
    data, level_chunks = h, 0
    for i in reversed(range(tree_depth + 1)):
      # moving onto the "tinyfs:load" device is what resolves hashes into chunk contents
      data = data.to("tinyfs:load")
      # if not last level, its still hashes
      if i > 0 or tree_depth == 0:
        # chunk count at this tree level.  NOTE(review): when tree_depth == 0 the
        # exponent (i-1) is -1, which inflates level_chunks so pad_amt becomes a full
        # CHUNK_SIZE -- looks intentional for the single-chunk case, but confirm.
        level_chunks = max(1, math.ceil(base_chunks / (Tensor.CHUNK_SIZE // 16)**(i-1)))
        pad_amt = 16 * level_chunks  # 16 bytes per hash record
      else: pad_amt = Tensor.CHUNK_SIZE * level_chunks  # data level: whole chunks
      # pad up (then trim) so each level is an exact multiple of its record size
      if (tsize := data.shape[0]) < pad_amt: data = data.pad((0, pad_amt - tsize))
      data = data[:pad_amt].contiguous()
      # intermediate hash levels are brought back to self's device; the final data
      # level is not -- NOTE(review): result appears to stay on the tinyfs device, confirm
      if i != 0: data = data.to(self.device)
    return data[:size]
  def store(self) -> Tensor:
    """
    Store a tensor to storage.

    Uploads `self`'s raw bytes to the tinyfs backend level by level and returns
    the 16-byte uint8 root-hash tensor that `load` accepts.
    """
    # TODO: this should work locally as well
    data = self.contiguous().flatten().bitcast(dtypes.uint8)
    # pad to a multiple of 1mb
    if (tsize := data.shape[0]) % Tensor.CHUNK_SIZE != 0: data = data.pad((0, Tensor.CHUNK_SIZE - tsize % Tensor.CHUNK_SIZE))
    size = data.shape[0]
    base_chunks = math.ceil(size / Tensor.CHUNK_SIZE)
    # 16-byte hashes, CHUNK_SIZE // 16 per chunk -> number of hash-tree levels above the data
    tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
    # DISK presumably can't hold the intermediate hash levels, so bounce via CPU -- TODO confirm
    to_device = "CPU" if isinstance(self.device, str) and self.device.startswith("DISK") else self.device
    level_chunks = base_chunks
    for _ in range(tree_depth + 1):
      # moving to "tinyfs:store" uploads the chunks; the first 16 bytes per chunk
      # come back as that chunk's hash (presumably -- confirm against server protocol)
      data = data.to("tinyfs:store")[:level_chunks * 16].contiguous().to(to_device)
      # re-pad the hash list to a chunk boundary so it can be stored as the next level up
      if (tsize := data.shape[0]) % Tensor.CHUNK_SIZE != 0: data = data.pad((0, Tensor.CHUNK_SIZE - tsize % Tensor.CHUNK_SIZE))
      level_chunks = math.ceil(data.shape[0] / Tensor.CHUNK_SIZE)
    # after the last level there is a single chunk; its 16-byte hash is the root
    return data[:16].contiguous()
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)