feat: Tensor.{load, store} (#12629)

wozeparrot
2025-10-13 08:04:41 -07:00
committed by GitHub
parent 0f776c6e46
commit 47e0c43976
3 changed files with 78 additions and 3 deletions
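
In rough terms, store() uploads a tensor's bytes to a tinyfs server in 1 MiB chunks, builds a hash tree over them, and collapses to the 16-byte (uint8) root hash, while load(size) starts from that hash and expands back to the original bytes. A minimal usage sketch, assuming a tinyfs server is reachable at TINYFS_ENDPOINT; the realize() calls and the bytes() round-trip are illustrative, not taken from this commit:

from tinygrad import Tensor, dtypes

data = Tensor.zeros(1024, dtype=dtypes.uint8)
h = data.store().realize()   # a tensor of any size collapses to its 16-byte root hash
assert h.shape == (16,)
back = Tensor(bytes(h.data())).load(1024).realize()  # the hash expands back to the original 1024 bytes

The unit tests below stop at kernelize(), which schedules kernels without executing them, so the shape assertions hold without a live server.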

test/unit/test_tinyfs.py (new file)

@@ -0,0 +1,22 @@
import unittest
from tinygrad import Tensor

class TestLoadStore(unittest.TestCase):
  def test_load_shape(self):
    t = Tensor(bytes(16)).load(1024).kernelize()
    assert t.shape == (1024,), t.shape

  def test_store_shape(self):
    t = Tensor.zeros(1024).store().kernelize()
    assert t.shape == (16,), t.shape

  def test_load_large_shape(self):
    t = Tensor(bytes(16)).load(10_000_000).kernelize()
    assert t.shape == (10_000_000,), t.shape

  def test_store_large_shape(self):
    t = Tensor.zeros(10_000_000).store().kernelize()
    assert t.shape == (16,), t.shape

if __name__ == "__main__":
  unittest.main()


@@ -2,9 +2,9 @@ import socket, uuid, json, asyncio, threading
 from contextlib import asynccontextmanager
 from tinygrad.device import Compiled, Allocator
 from tinygrad.helpers import DEBUG, getenv
+from tinygrad import Tensor
 TINYFS_ENDPOINT = getenv("TINYFS_ENDPOINT", "localhost:6767")
-CHUNK_SIZE = 2**20
 class TinyFSDevice(Compiled):
   def __init__(self, device:str):
@@ -116,8 +116,8 @@ class TinyFSAllocator(Allocator[TinyFSDevice]):
     async def _worker(item):
       i, loc, h = item
       async with self.dev.connection(loc) as (reader, writer):
-        ptr = i * CHUNK_SIZE
-        size = min(len(dest[ptr:ptr+CHUNK_SIZE]), CHUNK_SIZE)
+        ptr = i * Tensor.CHUNK_SIZE
+        size = min(len(dest[ptr:ptr+Tensor.CHUNK_SIZE]), Tensor.CHUNK_SIZE)
         writer.write(f"CHUNK_OUT {size}\r\n".encode())
         writer.write(h)


@@ -411,6 +411,59 @@ class Tensor(MathTrait):
""" """
return self.replace(self.shard(devices, axis)) return self.replace(self.shard(devices, axis))
CHUNK_SIZE = 2**20
def load(self, size:int) -> Tensor:
"""
Load a tensor from storage.
self should be a tensor of the hash to load
"""
# TODO: this should work locally as well
assert self.dtype == dtypes.uint8, "hash is expected to be uint8"
h = self.contiguous().flatten()
assert h.shape[0] == 16, "expected hash"
base_chunks = math.ceil(size / Tensor.CHUNK_SIZE)
tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
data, level_chunks = h, 0
for i in reversed(range(tree_depth + 1)):
data = data.to("tinyfs:load")
# if not last level, its still hashes
if i > 0 or tree_depth == 0:
level_chunks = max(1, math.ceil(base_chunks / (Tensor.CHUNK_SIZE // 16)**(i-1)))
pad_amt = 16 * level_chunks
else: pad_amt = Tensor.CHUNK_SIZE * level_chunks
if (tsize := data.shape[0]) < pad_amt: data = data.pad((0, pad_amt - tsize))
data = data[:pad_amt].contiguous()
if i != 0: data = data.to(self.device)
return data[:size]
def store(self) -> Tensor:
"""
Store a tensor to storage.
"""
# TODO: this should work locally as well
data = self.contiguous().flatten().bitcast(dtypes.uint8)
# pad to a multiple of 1mb
if (tsize := data.shape[0]) % Tensor.CHUNK_SIZE != 0: data = data.pad((0, Tensor.CHUNK_SIZE - tsize % Tensor.CHUNK_SIZE))
size = data.shape[0]
base_chunks = math.ceil(size / Tensor.CHUNK_SIZE)
tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
to_device = "CPU" if isinstance(self.device, str) and self.device.startswith("DISK") else self.device
level_chunks = base_chunks
for _ in range(tree_depth + 1):
data = data.to("tinyfs:store")[:level_chunks * 16].contiguous().to(to_device)
if (tsize := data.shape[0]) % Tensor.CHUNK_SIZE != 0: data = data.pad((0, Tensor.CHUNK_SIZE - tsize % Tensor.CHUNK_SIZE))
level_chunks = math.ceil(data.shape[0] / Tensor.CHUNK_SIZE)
return data[:16].contiguous()
@staticmethod @staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor: def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
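
For a feel for the chunk-tree arithmetic shared by load and store: a 1 MiB chunk holds CHUNK_SIZE // 16 = 65,536 16-byte hashes, so tree_depth counts the hash levels needed above the data chunks. A small standalone sketch (FANOUT and tree_shape are names introduced here for illustration):

import math

CHUNK_SIZE = 2**20         # 1 MiB data chunks
FANOUT = CHUNK_SIZE // 16  # 65536 16-byte hashes fit in one chunk

def tree_shape(size:int) -> tuple[int, int]:
  base_chunks = math.ceil(size / CHUNK_SIZE)             # data chunks at the bottom level
  tree_depth = math.ceil(math.log(base_chunks, FANOUT))  # hash levels above them
  return base_chunks, tree_depth

print(tree_shape(1024))                     # (1, 0): one chunk, its hash is the root
print(tree_shape(10_000_000))               # (10, 1): 10 chunks, hashes fit in one parent
print(tree_shape(FANOUT * CHUNK_SIZE + 1))  # (65537, 2): needs a second hash level

So the load(10_000_000) test above needs only one intermediate level: its 10 chunk hashes (160 bytes) fit in a single parent chunk, whose own 16-byte hash is the root that store() returns.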