mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Remote multihost (p2p transfer) (#10601)
This commit is contained in:
@@ -20,7 +20,7 @@ from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, Buffe
|
||||
# ***** API *****
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SessionKey: idx: int; nonce: str # noqa: E702
|
||||
class SessionKey: host: str; idx: int; nonce: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
|
||||
@@ -32,14 +32,19 @@ class SessionFree(RemoteRequest): pass
|
||||
class RemoteProperties:
|
||||
real_device: str
|
||||
renderer: tuple[str, str, tuple[Any, ...]]
|
||||
offset_supported: bool
|
||||
graph_supported: bool
|
||||
graph_supports_multi: bool
|
||||
transfer_supported: bool
|
||||
offset_supported: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetProperties(RemoteRequest): pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Wait(RemoteRequest): event: int
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
||||
@@ -101,9 +106,9 @@ class GraphExec(RemoteRequest):
|
||||
wait: bool
|
||||
|
||||
# for safe deserialization
|
||||
eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferOffset, BufferFree, CopyIn,
|
||||
CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec,
|
||||
BufferSpec, UOp, Ops, dtypes]}
|
||||
eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferFree,
|
||||
CopyIn, CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree,
|
||||
GraphExec, BufferSpec, UOp, Ops, dtypes]}
|
||||
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
|
||||
@@ -143,6 +148,7 @@ class RemoteSession:
|
||||
programs: dict[tuple[str, str], Any] = field(default_factory=dict)
|
||||
graphs: dict[int, GraphRunner] = field(default_factory=dict)
|
||||
buffers: dict[int, Buffer] = field(default_factory=dict)
|
||||
events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))
|
||||
|
||||
class RemoteHandler:
|
||||
def __init__(self, base_device: str):
|
||||
@@ -157,10 +163,10 @@ class RemoteHandler:
|
||||
key, value = hdr.split(':', 1)
|
||||
req_headers[key.lower()] = value.strip()
|
||||
req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
|
||||
res_status, res_body = self.handle(req_method, req_path, req_body)
|
||||
res_status, res_body = await self.handle(req_method, req_path, req_body)
|
||||
writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
|
||||
|
||||
def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
|
||||
async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
|
||||
status, ret = http.HTTPStatus.OK, b""
|
||||
if path == "/batch" and method == "POST":
|
||||
# TODO: streaming deserialize?
|
||||
@@ -175,11 +181,20 @@ class RemoteHandler:
|
||||
cls, args = dev.renderer.__reduce__()
|
||||
graph_cls = graph_class(Device[self.base_device])
|
||||
rp = RemoteProperties(
|
||||
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
|
||||
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
|
||||
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
|
||||
transfer_supported=hasattr(dev.allocator, '_transfer'), offset_supported=hasattr(dev.allocator, '_offset'),
|
||||
)
|
||||
ret = repr(rp).encode()
|
||||
case Event():
|
||||
if c.session == c.event_session:
|
||||
session.events[c.event].set()
|
||||
else:
|
||||
dev.synchronize() # wait for device to finish executing previous stuff
|
||||
# TODO: don't wait, just send
|
||||
RemoteConnection(c.event_session.host).q(Event(c.event_session, c.event, session=c.event_session), wait=True)
|
||||
case Wait():
|
||||
assert await session.events[c.event].wait()
|
||||
del session.events[c.event] # do not leak memory
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
@@ -190,11 +205,20 @@ class RemoteHandler:
|
||||
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
|
||||
case Transfer():
|
||||
dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
|
||||
dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
|
||||
assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
|
||||
assert hasattr(ddev.allocator, '_transfer'), f"Device {ddev.device} doesn't support transfers"
|
||||
ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
|
||||
if c.dsession.host == unwrap(c.session).host:
|
||||
dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
|
||||
dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
|
||||
if hasattr(ddev.allocator, '_transfer'):
|
||||
assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
|
||||
ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
|
||||
else:
|
||||
sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
|
||||
dbuf.copyin(data)
|
||||
else:
|
||||
conn = RemoteConnection(c.dsession.host)
|
||||
sbuf = session.buffers[c.buffer_num]
|
||||
sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
|
||||
conn.q(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
|
||||
case ProgramAlloc():
|
||||
lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
|
||||
@@ -256,11 +280,14 @@ class RemoteAllocator(Allocator['RemoteDevice']):
|
||||
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
|
||||
dest[:] = resp
|
||||
def _transfer(self, dest, src, sz, src_dev, dest_dev):
|
||||
if dest_dev.properties.transfer_supported and src_dev.conn == dest_dev.conn:
|
||||
src_dev.q(Transfer(src, dest_dev.session, dest))
|
||||
else:
|
||||
src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
|
||||
dest_dev.allocator._copyin(dest, tmp)
|
||||
if dest_dev.conn != src_dev.conn:
|
||||
dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
|
||||
src_dev.q(Wait(start_event))
|
||||
src_dev.q(Transfer(src, dest_dev.session, dest))
|
||||
if dest_dev.conn != src_dev.conn:
|
||||
src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
|
||||
dest_dev.q(Wait(end_event))
|
||||
if DEBUG >= 2: dest_dev.conn.batch_submit()
|
||||
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
|
||||
self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
|
||||
return buffer_num
|
||||
@@ -282,6 +309,8 @@ class RemoteProgram:
|
||||
|
||||
@functools.cache
|
||||
class RemoteConnection:
|
||||
all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering
|
||||
|
||||
def __init__(self, host:str):
|
||||
if DEBUG >= 1: print(f"remote with host {host}")
|
||||
while 1:
|
||||
@@ -293,15 +322,24 @@ class RemoteConnection:
|
||||
print(e)
|
||||
time.sleep(0.1)
|
||||
self.req: BatchRequest = BatchRequest()
|
||||
RemoteConnection.all[self] = None
|
||||
|
||||
def q(self, x:RemoteRequest, wait:bool=False):
|
||||
self.req.q(x)
|
||||
if wait: return self.batch_submit()
|
||||
|
||||
def batch_submit(self):
|
||||
data = self.req.serialize()
|
||||
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=3):
|
||||
self.conn.request("POST", "/batch", data)
|
||||
response = self.conn.getresponse()
|
||||
assert response.status == 200, f"POST /batch failed: {response}"
|
||||
ret = response.read()
|
||||
self.req = BatchRequest()
|
||||
conns = RemoteConnection.all.keys()
|
||||
datas = {conn: conn.req.serialize() for conn in conns}
|
||||
reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
|
||||
with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
|
||||
for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
|
||||
for conn in datas.keys():
|
||||
response = conn.conn.getresponse()
|
||||
assert response.status == 200, f"POST /batch failed: {response}"
|
||||
resp = response.read()
|
||||
if conn == self: ret = resp
|
||||
conn.req = BatchRequest()
|
||||
return ret
|
||||
|
||||
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
|
||||
@@ -316,12 +354,13 @@ class RemoteDevice(Compiled):
|
||||
host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
|
||||
|
||||
# connection is shared between sessions on the same host
|
||||
self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server())
|
||||
self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
|
||||
self.conn: RemoteConnection = RemoteConnection(self.session.host)
|
||||
|
||||
# state for the session
|
||||
self.session = SessionKey(idx, binascii.hexlify(os.urandom(0x10)).decode())
|
||||
self.buffer_num: Iterator[int] = itertools.count(0)
|
||||
self.graph_num: Iterator[int] = itertools.count(0)
|
||||
self.event_num: Iterator[int] = itertools.count(0)
|
||||
|
||||
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
|
||||
@@ -339,9 +378,7 @@ class RemoteDevice(Compiled):
|
||||
def finalize(self):
|
||||
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
|
||||
|
||||
def q(self, x:RemoteRequest, wait:bool=False):
|
||||
self.conn.req.q(replace(x, session=self.session))
|
||||
if wait: return self.conn.batch_submit()
|
||||
def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)
|
||||
|
||||
@functools.cache
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user