Remote multihost (p2p transfer) (#10601)

Author: uuuvn
Date: 2025-06-23 18:47:29 +00:00
Committed by: GitHub
Parent: 42b1c9625b
Commit: 4e2c9e36c7

@@ -20,7 +20,7 @@ from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, Buffe
# ***** API *****
@dataclass(frozen=True)
-class SessionKey: idx: int; nonce: str # noqa: E702
+class SessionKey: host: str; idx: int; nonce: str # noqa: E702
@dataclass(frozen=True)
class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
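With host folded into SessionKey, a session key on its own says both which server owns the session and how to reach that server, which is what lets the Transfer and Event handlers further down open a RemoteConnection straight to a peer. A minimal sketch of the key/request shape (the host strings and the BufferFree stand-in are made up for illustration):

```python
# Sketch only: trimmed copies of the dataclasses above, with hypothetical hosts.
from dataclasses import dataclass, field, replace

@dataclass(frozen=True)
class SessionKey: host: str; idx: int; nonce: str  # noqa: E702

@dataclass(frozen=True)
class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)

@dataclass(frozen=True)
class BufferFree(RemoteRequest): buffer_num: int  # stand-in for any concrete request

s0 = SessionKey("10.0.0.1:6667", 0, "aabb")
s1 = SessionKey("10.0.0.2:6667", 0, "ccdd")   # same idx on another host -> a different session
req = replace(BufferFree(7), session=s0)      # RemoteDevice.q stamps the session the same way
assert s0 != s1 and req.session.host == "10.0.0.1:6667"
```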
@@ -32,14 +32,19 @@ class SessionFree(RemoteRequest): pass
class RemoteProperties:
real_device: str
renderer: tuple[str, str, tuple[Any, ...]]
+offset_supported: bool
graph_supported: bool
graph_supports_multi: bool
transfer_supported: bool
-offset_supported: bool
@dataclass(frozen=True)
class GetProperties(RemoteRequest): pass
+@dataclass(frozen=True)
+class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
+@dataclass(frozen=True)
+class Wait(RemoteRequest): event: int
@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@@ -101,9 +106,9 @@ class GraphExec(RemoteRequest):
wait: bool
# for safe deserialization
-eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferOffset, BufferFree, CopyIn,
-CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec,
-BufferSpec, UOp, Ops, dtypes]}
+eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferFree,
+CopyIn, CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree,
+GraphExec, BufferSpec, UOp, Ops, dtypes]}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
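Requests travel over the wire as their repr() and are rebuilt by safe_eval, which walks the parsed AST and only resolves names listed in eval_globals (plus the whitelisted attributes on dtypes and Ops), so the new Event and Wait classes have to be registered here or the server could never deserialize them. A stripped-down sketch of that round trip, with a toy resolver standing in for the real safe_eval:

```python
# Sketch of the repr -> ast -> object round trip with a toy whitelist (not the real safe_eval).
import ast
from dataclasses import dataclass

@dataclass(frozen=True)
class Wait: event: int  # reduced copy of the request above

allowed = {cls.__name__: cls for cls in (Wait,)}  # the moral equivalent of eval_globals

def toy_eval(node: ast.AST):
  match node:
    case ast.Constant(value=v): return v
    case ast.Name(id=name): return allowed[name]  # unknown names simply KeyError out
    case ast.Call(func=f, args=a, keywords=kw):
      return toy_eval(f)(*[toy_eval(x) for x in a], **{k.arg: toy_eval(k.value) for k in kw})
    case _: raise RuntimeError(f"node {type(node).__name__} not allowed")

wire = repr(Wait(event=3))                              # "Wait(event=3)" is what goes in the batch
assert toy_eval(ast.parse(wire, mode="eval").body) == Wait(event=3)
```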
@@ -143,6 +148,7 @@ class RemoteSession:
programs: dict[tuple[str, str], Any] = field(default_factory=dict)
graphs: dict[int, GraphRunner] = field(default_factory=dict)
buffers: dict[int, Buffer] = field(default_factory=dict)
+events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))
class RemoteHandler:
def __init__(self, base_device: str):
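Each session now carries its own table of numbered asyncio.Event objects; default_factory must be a zero-argument callable that builds a fresh table per session, hence functools.partial(defaultdict, asyncio.Event) rather than a shared dict. A small sketch of the behaviour this buys (plain stdlib, the Session name is hypothetical):

```python
# Sketch: each session-like object gets its own lazily-populated event table.
import asyncio, functools
from collections import defaultdict
from dataclasses import dataclass, field

@dataclass
class Session:
  events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))

async def main():
  a, b = Session(), Session()
  assert a.events is not b.events                    # a fresh table per session
  waiter = asyncio.create_task(a.events[5].wait())   # a Wait(5) touches the key and blocks
  await asyncio.sleep(0)
  a.events[5].set()                                  # a later Event(5) flips the same object
  await waiter
  del a.events[5]                                    # one-shot, mirrors "do not leak memory"

asyncio.run(main())
```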
@@ -157,10 +163,10 @@ class RemoteHandler:
key, value = hdr.split(':', 1)
req_headers[key.lower()] = value.strip()
req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
-res_status, res_body = self.handle(req_method, req_path, req_body)
+res_status, res_body = await self.handle(req_method, req_path, req_body)
writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
-def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
+async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
status, ret = http.HTTPStatus.OK, b""
if path == "/batch" and method == "POST":
# TODO: streaming deserialize?
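handle becomes a coroutine because a Wait request has to be able to suspend without blocking the event loop, so requests arriving on other connections (including the matching Event) keep being served. A minimal sketch of the same StreamReader/StreamWriter pattern, with the real request dispatch replaced by a toy handler:

```python
# Sketch: asyncio.start_server plus an async handler, same shape as the loop above.
import asyncio, http

async def handle(method: str, path: str, body: bytes) -> tuple[http.HTTPStatus, bytes]:
  await asyncio.sleep(0)                       # placeholder for awaiting session.events[...]
  return http.HTTPStatus.OK, b"ok:" + body

async def serve(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
  while (line := (await reader.readline()).decode().strip()):
    method, path, _ = line.split(" ", 2)
    headers = {}
    while (hdr := (await reader.readline()).decode().strip()):
      key, value = hdr.split(":", 1)
      headers[key.lower()] = value.strip()
    body = await reader.readexactly(int(headers.get("content-length", "0")))
    status, resp = await handle(method, path, body)
    writer.write(f"HTTP/1.1 {status.value} {status.phrase}\r\nContent-Length: {len(resp)}\r\n\r\n".encode() + resp)
    await writer.drain()

async def main():
  server = await asyncio.start_server(serve, "127.0.0.1", 0)   # port 0: pick any free port
  async with server: await asyncio.sleep(0.1)                  # a real server would serve_forever()

asyncio.run(main())
```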
@@ -175,11 +181,20 @@ class RemoteHandler:
cls, args = dev.renderer.__reduce__()
graph_cls = graph_class(Device[self.base_device])
rp = RemoteProperties(
-real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
+real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
-transfer_supported=hasattr(dev.allocator, '_transfer'), offset_supported=hasattr(dev.allocator, '_offset'),
)
ret = repr(rp).encode()
+case Event():
+if c.session == c.event_session:
+session.events[c.event].set()
+else:
+dev.synchronize() # wait for device to finish executing previous stuff
+# TODO: don't wait, just send
+RemoteConnection(c.event_session.host).q(Event(c.event_session, c.event, session=c.event_session), wait=True)
+case Wait():
+assert await session.events[c.event].wait()
+del session.events[c.event] # do not leak memory
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
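The Event case splits on ownership: if the event belongs to a session on this very server (c.session == c.event_session) it is set locally; otherwise the server first synchronizes its device, so everything queued ahead of the event has really finished, and then forwards the Event to the host named in event_session. A condensed sketch of that routing, with the device and peer connection reduced to injected callables (synchronize and forward are stand-ins, not the real API):

```python
# Sketch of the local-vs-forward decision for Event, plus how a later Wait observes it.
import asyncio
from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class SessionKey: host: str; idx: int; nonce: str  # noqa: E702

events: dict[SessionKey, defaultdict[int, asyncio.Event]] = {}

async def handle_event(current: SessionKey, event_session: SessionKey, event: int, synchronize, forward):
  if current == event_session:
    events.setdefault(current, defaultdict(asyncio.Event))[event].set()  # local: just flip the flag
  else:
    synchronize()                         # make sure queued device work is done first
    forward(event_session.host, event)    # then hand the Event to the host that owns the session

async def demo():
  me = SessionKey("hostA:6667", 0, "n0")
  await handle_event(me, me, 0, synchronize=lambda: None, forward=lambda host, ev: None)
  await events[me][0].wait()              # a Wait(0) on this session now returns immediately

asyncio.run(demo())
```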
@@ -190,11 +205,20 @@ class RemoteHandler:
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
case Transfer():
-dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
-dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
-assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
-assert hasattr(ddev.allocator, '_transfer'), f"Device {ddev.device} doesn't support transfers"
-ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
+if c.dsession.host == unwrap(c.session).host:
+dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
+dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
+if hasattr(ddev.allocator, '_transfer'):
+assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
+ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
+else:
+sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+dbuf.copyin(data)
+else:
+conn = RemoteConnection(c.dsession.host)
+sbuf = session.buffers[c.buffer_num]
+sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+conn.q(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
case ProgramAlloc():
lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
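Transfer now picks one of three paths: same host with a native _transfer (direct device-to-device copy), same host without one (bounce the bytes through host memory with copyout/copyin), or a different host (read the source buffer out and push it to the destination server as a CopyIn over a RemoteConnection). A sketch of that decision tree with buffers and the peer push reduced to the few operations the branch actually uses (FakeBuf and push_to_peer are illustrative):

```python
# Sketch of the server-side routing for Transfer, with duck-typed stand-ins.
class FakeBuf:
  def __init__(self, nbytes): self.nbytes, self.data = nbytes, bytearray(nbytes)
  def copyout(self, mv): mv[:] = self.data
  def copyin(self, mv): self.data[:] = mv

def route_transfer(src_host, dst_host, sbuf, dbuf, dev_transfer=None, push_to_peer=None) -> str:
  if src_host == dst_host:
    if dev_transfer is not None:                                # same host, native p2p path
      assert dbuf.nbytes == sbuf.nbytes
      dev_transfer(dbuf, sbuf, dbuf.nbytes)
      return "p2p"
    sbuf.copyout(data := memoryview(bytearray(sbuf.nbytes)))    # same host, no _transfer: bounce via RAM
    dbuf.copyin(data)
    return "bounce"
  sbuf.copyout(data := memoryview(bytearray(sbuf.nbytes)))      # other host: becomes a CopyIn over HTTP
  push_to_peer(dst_host, bytes(data))
  return "copyin-to-peer"

src, dst = FakeBuf(4), FakeBuf(4); src.data[:] = b"ping"
assert route_transfer("a", "a", src, dst) == "bounce" and bytes(dst.data) == b"ping"
```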
@@ -256,11 +280,14 @@ class RemoteAllocator(Allocator['RemoteDevice']):
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
def _transfer(self, dest, src, sz, src_dev, dest_dev):
-if dest_dev.properties.transfer_supported and src_dev.conn == dest_dev.conn:
-src_dev.q(Transfer(src, dest_dev.session, dest))
-else:
-src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
-dest_dev.allocator._copyin(dest, tmp)
+if dest_dev.conn != src_dev.conn:
+dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
+src_dev.q(Wait(start_event))
+src_dev.q(Transfer(src, dest_dev.session, dest))
+if dest_dev.conn != src_dev.conn:
+src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
+dest_dev.q(Wait(end_event))
+if DEBUG >= 2: dest_dev.conn.batch_submit()
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
return buffer_num
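On the client, _transfer only needs the Event/Wait fencing when source and destination sit behind different connections: the destination queues an Event aimed at the source's session and the source queues a Wait on it, so the source server does not start the transfer until the destination host has drained everything queued before it; after the Transfer the roles flip, so nothing queued behind it on the destination runs before the data has landed. A stub replay of that logic that just records what each connection would have queued (StubDev and the tuple requests are illustrative, not the real classes):

```python
# Sketch: replaying the fence logic above against stub devices that record queued requests.
import itertools
from dataclasses import dataclass, field

@dataclass
class StubDev:
  session: str
  conn: str                                   # stands in for the shared RemoteConnection
  event_num: itertools.count = field(default_factory=lambda: itertools.count(0))
  queued: list = field(default_factory=list)
  def q(self, req): self.queued.append(req)

def transfer(src_dev, dest_dev, src="sbuf", dest="dbuf"):
  if dest_dev.conn != src_dev.conn:
    dest_dev.q(("Event", src_dev.session, start_event := next(src_dev.event_num)))
    src_dev.q(("Wait", start_event))
  src_dev.q(("Transfer", src, dest_dev.session, dest))
  if dest_dev.conn != src_dev.conn:
    src_dev.q(("Event", dest_dev.session, end_event := next(dest_dev.event_num)))
    dest_dev.q(("Wait", end_event))

a, b = StubDev("S-a", "hostA"), StubDev("S-b", "hostB")
transfer(a, b)
assert a.queued == [("Wait", 0), ("Transfer", "sbuf", "S-b", "dbuf"), ("Event", "S-b", 0)]
assert b.queued == [("Event", "S-a", 0), ("Wait", 0)]
```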
@@ -282,6 +309,8 @@ class RemoteProgram:
@functools.cache
class RemoteConnection:
+all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering
def __init__(self, host:str):
if DEBUG >= 1: print(f"remote with host {host}")
while 1:
@@ -293,15 +322,24 @@ class RemoteConnection:
print(e)
time.sleep(0.1)
self.req: BatchRequest = BatchRequest()
+RemoteConnection.all[self] = None
+def q(self, x:RemoteRequest, wait:bool=False):
+self.req.q(x)
+if wait: return self.batch_submit()
def batch_submit(self):
-data = self.req.serialize()
-with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=3):
-self.conn.request("POST", "/batch", data)
-response = self.conn.getresponse()
-assert response.status == 200, f"POST /batch failed: {response}"
-ret = response.read()
-self.req = BatchRequest()
+conns = RemoteConnection.all.keys()
+datas = {conn: conn.req.serialize() for conn in conns}
+reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
+with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
+for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
+for conn in datas.keys():
+response = conn.conn.getresponse()
+assert response.status == 200, f"POST /batch failed: {response}"
+resp = response.read()
+if conn == self: ret = resp
+conn.req = BatchRequest()
return ret
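batch_submit is now cluster-wide: it serializes the pending batch of every known RemoteConnection, posts them all, and only then reads the responses, so a Wait sitting in one host's batch cannot deadlock on an Event that is still unsent in another host's batch; the caller still only gets back the body from its own connection. A sketch of that send-all-then-read-all shape with the HTTP client stubbed out (StubConn is hypothetical):

```python
# Sketch: pipelined flush across several connections, HTTP replaced by a stub.
class StubConn:
  def __init__(self, name): self.name, self.pending, self.inflight = name, [], None
  def q(self, req): self.pending.append(req)
  def request(self, data): self.inflight = f"{self.name} ran {data}"   # like conn.request("POST", "/batch", ...)
  def getresponse(self): return self.inflight                          # like conn.getresponse().read()

all_conns: dict[StubConn, None] = {}                   # mirrors RemoteConnection.all

def batch_submit(caller: StubConn) -> str:
  datas = {c: list(c.pending) for c in all_conns}      # serialize every connection's pending batch
  for c, data in datas.items(): c.request(data)        # 1) post every batch first...
  ret = ""
  for c in datas:                                      # 2) ...then collect every response
    resp = c.getresponse()
    if c is caller: ret = resp
    c.pending = []                                     # fresh batch, like self.req = BatchRequest()
  return ret

a, b = StubConn("hostA"), StubConn("hostB")
all_conns[a] = all_conns[b] = None
a.q("Wait(0)"); b.q("Event(...)")
print(batch_submit(a))                                 # only the caller's own response is returned
```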
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
@@ -316,12 +354,13 @@ class RemoteDevice(Compiled):
host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
# connection is shared between sessions on the same host
-self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server())
+self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
+self.conn: RemoteConnection = RemoteConnection(self.session.host)
# state for the session
-self.session = SessionKey(idx, binascii.hexlify(os.urandom(0x10)).decode())
self.buffer_num: Iterator[int] = itertools.count(0)
self.graph_num: Iterator[int] = itertools.count(0)
+self.event_num: Iterator[int] = itertools.count(0)
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
@@ -339,9 +378,7 @@ class RemoteDevice(Compiled):
def finalize(self):
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
-def q(self, x:RemoteRequest, wait:bool=False):
-self.conn.req.q(replace(x, session=self.session))
-if wait: return self.conn.batch_submit()
+def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)
@functools.cache
@staticmethod