mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
feat: one request per device (#12698)
This commit is contained in:
@@ -32,6 +32,9 @@ class TinyFSDevice(Compiled):
|
||||
self.conn_pools: dict[str, asyncio.Queue] = {}
|
||||
self.conn_pools_lock = asyncio.Lock()
|
||||
|
||||
# current request
|
||||
self.request_id = uuid.UUID(int=0)
|
||||
|
||||
def finalize(self):
|
||||
self.sfile.close()
|
||||
|
||||
@@ -71,9 +74,8 @@ class TinyFSDevice(Compiled):
|
||||
await self.conn_pools[loc].put((reader, writer))
|
||||
|
||||
class TinyFSBuffer:
|
||||
def __init__(self, device:TinyFSDevice, size:int, offset=0, request_id=None, copyout_queue=None):
|
||||
def __init__(self, device:TinyFSDevice, size:int, offset=0, copyout_queue=None):
|
||||
self.device, self.size, self.offset = device, size, offset
|
||||
self.request_id: uuid.UUID|None = request_id
|
||||
self.copyout_queue = copyout_queue or []
|
||||
def __repr__(self): return f"<TinyFSBuffer size={self.size} offset={self.offset}>"
|
||||
|
||||
@@ -87,8 +89,8 @@ class TinyFSAllocator(Allocator[TinyFSDevice]):
|
||||
|
||||
if dest.device.op == "STORE":
|
||||
self.dev.sfile.flush()
|
||||
dest.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
|
||||
if DEBUG >= 2: print(f"Request ID: {dest.request_id}")
|
||||
self.dev.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
|
||||
if DEBUG >= 2: print(f"Request ID: {self.dev.request_id}")
|
||||
|
||||
self.dev.sfile.write(src)
|
||||
self.dev.sfile.flush()
|
||||
@@ -99,17 +101,15 @@ class TinyFSAllocator(Allocator[TinyFSDevice]):
|
||||
|
||||
dest.copyout_queue = []
|
||||
for i, loc in enumerate(locs):
|
||||
dest.copyout_queue.append((i, loc, src[i*16:(i+1)*16]))
|
||||
dest.copyout_queue.append((i, loc, src[i*16:(i+1)*16].tobytes()))
|
||||
|
||||
def _copyout(self, dest:memoryview, src:TinyFSBuffer):
|
||||
if DEBUG >= 2: print(f"Copying out {src.size} bytes from TINYFS:{src.device.op}")
|
||||
if src.device.op == "LOAD":
|
||||
asyncio.run_coroutine_threadsafe(self._copyout_async(dest, src), src.device.loop).result()
|
||||
else:
|
||||
self.dev.sfile.write(f"{src.device.op}_OUT {src.size} {src.request_id}\r\n".encode())
|
||||
self.dev.sfile.write(f"{src.device.op}_OUT {src.size} {self.dev.request_id}\r\n".encode())
|
||||
self.dev.sfile.flush()
|
||||
src.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
|
||||
if DEBUG >= 2: print(f"Request ID: {src.request_id}")
|
||||
self.dev.sfile.readinto(dest)
|
||||
|
||||
async def _copyout_async(self, dest:memoryview, src:TinyFSBuffer):
|
||||
@@ -131,7 +131,6 @@ class TinyFSAllocator(Allocator[TinyFSDevice]):
|
||||
|
||||
workers = [asyncio.create_task(_worker(item)) for item in src.copyout_queue]
|
||||
await asyncio.gather(*workers)
|
||||
src.copyout_queue.clear()
|
||||
|
||||
def _offset(self, buf:TinyFSBuffer, size:int, offset:int):
|
||||
return TinyFSBuffer(buf.device, size, offset, buf.request_id, buf.copyout_queue)
|
||||
return TinyFSBuffer(buf.device, size, offset, buf.copyout_queue)
|
||||
|
||||
Reference in New Issue
Block a user