mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
actual tinyfs device (#12620)
This commit is contained in:
@@ -121,7 +121,7 @@ class BufferCopy(Runner):
|
||||
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
|
||||
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
|
||||
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
|
||||
elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
|
||||
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
# fast(ish) path, uses readinto in diskbuffers
|
||||
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
|
||||
else:
|
||||
|
||||
137
tinygrad/runtime/ops_tinyfs.py
Normal file
137
tinygrad/runtime/ops_tinyfs.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import socket, uuid, json, asyncio, threading
|
||||
from contextlib import asynccontextmanager
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
|
||||
# host:port of the tinyfs server; override via the TINYFS_ENDPOINT env var
TINYFS_ENDPOINT = getenv("TINYFS_ENDPOINT", "localhost:6767")
# chunk granularity (1 MiB) used to slice buffers during parallel copyout
CHUNK_SIZE = 2**20
|
||||
|
||||
class TinyFSDevice(Compiled):
  """Device that talks to a tinyfs node over TCP.

  Keeps one blocking control socket (`self.sfile`) for the request/response protocol and a
  daemon thread running a private asyncio event loop used for parallel chunk copyout.
  """
  def __init__(self, device:str):
    # device string looks like "tinyfs:store" / "tinyfs:load" -> op becomes "STORE"/"LOAD"
    self.op = device[len("tinyfs:"):].upper()
    super().__init__(device, TinyFSAllocator(self), None, None, None)

    host, port = TINYFS_ENDPOINT.rsplit(":", 1)
    self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.sock.connect((host, int(port)))
    self.sfile = self.sock.makefile("rwb")

    # fetch node info: server answers INFO with a single JSON line
    self.sfile.write(b"INFO\r\n")
    self.sfile.flush()
    self.node_info = json.loads(self.sfile.readline())
    if DEBUG >= 2: print(f"nodes: {self.node_info}")

    # spawn the event-loop thread for async copyout and wait until the loop exists
    self.start_event = threading.Event()
    self.t = threading.Thread(target=self._start_thread, daemon=True)
    self.t.start()
    self.start_event.wait()

    # per-location connection pools, lazily created by connection()
    self.conn_pools: dict[str, asyncio.Queue] = {}
    self.conn_pools_lock = asyncio.Lock()

  def finalize(self):
    self.sfile.close()

    # drain every pool and close its stream writers on the loop thread
    for pool in self.conn_pools.values():
      while not pool.empty():
        _, w = pool.get_nowait()
        w.close()
        asyncio.run_coroutine_threadsafe(w.wait_closed(), self.loop).result()

    if hasattr(self, "loop"):
      self.loop.call_soon_threadsafe(self.loop.stop)
      self.t.join()

  def _start_thread(self):
    # runs in self.t: owns the event loop until finalize() stops it
    self.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self.loop)

    self.start_event.set()
    self.loop.run_forever()
    self.loop.close()

  @asynccontextmanager
  async def connection(self, loc):
    """Borrow a (reader, writer) pair for node `loc`, building the pool on first use."""
    if loc not in self.conn_pools:
      # BUGFIX: `async with` releases the lock even if open_connection raises; the previous
      # bare acquire()/release() leaked the lock on error and deadlocked later callers.
      async with self.conn_pools_lock:
        if loc not in self.conn_pools:  # double-check under the lock
          pool = asyncio.Queue(nw:=getenv("ASYNC_COPY_WORKERS", 4))
          conn_tasks = [asyncio.open_connection(*self.node_info[loc][-1].rsplit(":", 1)) for _ in range(nw)]
          for reader, writer in await asyncio.gather(*conn_tasks): pool.put_nowait((reader, writer))
          # publish only after the pool is fully populated
          self.conn_pools[loc] = pool

    reader, writer = await self.conn_pools[loc].get()
    try:
      yield reader, writer
    finally:
      await self.conn_pools[loc].put((reader, writer))
|
||||
|
||||
class TinyFSBuffer:
  """Handle for data living on a tinyfs device; carries bookkeeping, not bytes."""
  def __init__(self, device:TinyFSDevice, size:int, offset=0, request_id=None, copyout_queue=None):
    self.device = device
    self.size = size
    self.offset = offset
    # id handed back by the server on STORE; stays None until assigned
    self.request_id: uuid.UUID|None = request_id
    # pending (index, location, hash) items consumed by the async copyout path
    self.copyout_queue = [] if not copyout_queue else copyout_queue
  def __repr__(self): return f"<TinyFSBuffer size={self.size} offset={self.offset}>"
|
||||
|
||||
class TinyFSAllocator(Allocator[TinyFSDevice]):
  """Allocator speaking the tinyfs wire protocol over the device's control socket.

  Commands are text lines ("<OP>_IN <size>\\r\\n", "CHUNK_OUT <size>\\r\\n", ...) followed by raw
  bytes; the write/flush/read ordering below IS the protocol — do not reorder.
  """
  def _alloc(self, size, options):
    # no real allocation: just a handle; bytes move in _copyin/_copyout
    return TinyFSBuffer(self.dev, size)

  def _copyin(self, dest:TinyFSBuffer, src:memoryview):
    # announce the operation and payload size
    if DEBUG >= 2: print(f"Copying in {dest.size} bytes to TINYFS:{dest.device.op}")
    self.dev.sfile.write(f"{dest.device.op}_IN {dest.size}\r\n".encode())

    if dest.device.op == "STORE":
      # server assigns a 16-byte request id before the payload is sent
      self.dev.sfile.flush()
      dest.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
      if DEBUG >= 2: print(f"Request ID: {dest.request_id}")

    self.dev.sfile.write(src)
    self.dev.sfile.flush()

    if dest.device.op == "LOAD":
      # server replies with a JSON line of chunk locations, one per 16-byte entry in src
      # (src presumably holds 16-byte chunk hashes — TODO confirm against the server protocol)
      locs = self.dev.sfile.readline()
      locs = json.loads(locs)

      # queue (chunk index, node location, 16-byte hash) items for _copyout_async
      dest.copyout_queue = []
      for i, loc in enumerate(locs):
        dest.copyout_queue.append((i, loc, src[i*16:(i+1)*16]))

  def _copyout(self, dest:memoryview, src:TinyFSBuffer):
    if DEBUG >= 2: print(f"Copying out {src.size} bytes from TINYFS:{src.device.op}")
    if src.device.op == "LOAD":
      # LOAD fetches chunks in parallel on the device's event-loop thread; block until done
      asyncio.run_coroutine_threadsafe(self._copyout_async(dest, src), src.device.loop).result()
    else:
      # other ops stream the whole buffer back over the control socket
      self.dev.sfile.write(f"{src.device.op}_OUT {src.size} {src.request_id}\r\n".encode())
      self.dev.sfile.flush()
      src.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
      if DEBUG >= 2: print(f"Request ID: {src.request_id}")
      self.dev.sfile.readinto(dest)

  async def _copyout_async(self, dest:memoryview, src:TinyFSBuffer):
    # one worker per queued chunk; each borrows a pooled connection to the chunk's node
    async def _worker(item):
      i, loc, h = item
      async with self.dev.connection(loc) as (reader, writer):
        ptr = i * CHUNK_SIZE
        # last chunk may be short: clamp to what remains of dest
        size = min(len(dest[ptr:ptr+CHUNK_SIZE]), CHUNK_SIZE)

        # request the chunk by its 16-byte hash
        writer.write(f"CHUNK_OUT {size}\r\n".encode())
        writer.write(h)
        await writer.drain()

        chunk = await reader.readexactly(size)

        # write the chunk into its slice of dest
        view = dest[ptr:ptr+len(chunk)]
        view[:] = chunk
        del view

    workers = [asyncio.create_task(_worker(item)) for item in src.copyout_queue]
    await asyncio.gather(*workers)
    src.copyout_queue.clear()

  def _offset(self, buf:TinyFSBuffer, size:int, offset:int):
    # sub-buffer view: shares the request id and pending copyout work of the parent
    return TinyFSBuffer(buf.device, size, offset, buf.request_id, buf.copyout_queue)
|
||||
@@ -207,7 +207,7 @@ pm_cleanups = pm_mops+PatternMatcher([
|
||||
])
|
||||
|
||||
def late_buffer_view(t:UOp, b:UOp):
|
||||
if isinstance(b.device, str) and b.device.startswith("DISK"):
|
||||
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
|
||||
rngs = b.src[1:]
|
||||
size = prod(shape := [int(r.vmax+1) for r in rngs])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user