diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py
index cc11afaf12..560ee99c4f 100644
--- a/tinygrad/runtime/ops_remote.py
+++ b/tinygrad/runtime/ops_remote.py
@@ -20,7 +20,7 @@ from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, Buffe
 # ***** API *****
 
 @dataclass(frozen=True)
-class SessionKey: idx: int; nonce: str # noqa: E702
+class SessionKey: host: str; idx: int; nonce: str # noqa: E702
 
 @dataclass(frozen=True)
 class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
@@ -32,14 +32,19 @@ class SessionFree(RemoteRequest): pass
 class RemoteProperties:
   real_device: str
   renderer: tuple[str, str, tuple[Any, ...]]
+  offset_supported: bool
   graph_supported: bool
   graph_supports_multi: bool
-  transfer_supported: bool
-  offset_supported: bool
 
 @dataclass(frozen=True)
 class GetProperties(RemoteRequest): pass
 
+@dataclass(frozen=True)
+class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
+
+@dataclass(frozen=True)
+class Wait(RemoteRequest): event: int
+
 @dataclass(frozen=True)
 class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
 
@@ -101,9 +106,9 @@ class GraphExec(RemoteRequest):
   wait: bool
 
 # for safe deserialization
-eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferOffset, BufferFree, CopyIn,
-                                       CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec,
-                                       BufferSpec, UOp, Ops, dtypes]}
+eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferFree,
+                                       CopyIn, CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree,
+                                       GraphExec, BufferSpec, UOp, Ops, dtypes]}
 attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
 eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
              ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
@@ -143,6 +148,7 @@ class RemoteSession:
   programs: dict[tuple[str, str], Any] = field(default_factory=dict)
   graphs: dict[int, GraphRunner] = field(default_factory=dict)
   buffers: dict[int, Buffer] = field(default_factory=dict)
+  events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))
 
 class RemoteHandler:
   def __init__(self, base_device: str):
@@ -157,10 +163,10 @@ class RemoteHandler:
         key, value = hdr.split(':', 1)
         req_headers[key.lower()] = value.strip()
       req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
-      res_status, res_body = self.handle(req_method, req_path, req_body)
+      res_status, res_body = await self.handle(req_method, req_path, req_body)
       writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
 
-  def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
+  async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
     status, ret = http.HTTPStatus.OK, b""
     if path == "/batch" and method == "POST":
       # TODO: streaming deserialize?
@@ -175,11 +181,20 @@ class RemoteHandler:
             cls, args = dev.renderer.__reduce__()
            graph_cls = graph_class(Device[self.base_device])
            rp = RemoteProperties(
-              real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
+              real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
               graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
-              transfer_supported=hasattr(dev.allocator, '_transfer'), offset_supported=hasattr(dev.allocator, '_offset'),
             )
            ret = repr(rp).encode()
+          case Event():
+            if c.session == c.event_session:
+              session.events[c.event].set()
+            else:
+              dev.synchronize() # wait for device to finish executing previous stuff
+              # TODO: don't wait, just send
+              RemoteConnection(c.event_session.host).q(Event(c.event_session, c.event, session=c.event_session), wait=True)
+          case Wait():
+            assert await session.events[c.event].wait()
+            del session.events[c.event] # do not leak memory
          case BufferAlloc():
            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
            session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
@@ -190,11 +205,20 @@ class RemoteHandler:
          case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
          case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
          case Transfer():
-            dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
-            dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
-            assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
-            assert hasattr(ddev.allocator, '_transfer'), f"Device {ddev.device} doesn't support transfers"
-            ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
+            if c.dsession.host == unwrap(c.session).host:
+              dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
+              dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
+              if hasattr(ddev.allocator, '_transfer'):
+                assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
+                ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
+              else:
+                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+                dbuf.copyin(data)
+            else:
+              conn = RemoteConnection(c.dsession.host)
+              sbuf = session.buffers[c.buffer_num]
+              sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+              conn.q(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
          case ProgramAlloc():
            lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
            session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
@@ -256,11 +280,14 @@ class RemoteAllocator(Allocator['RemoteDevice']):
     assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
     dest[:] = resp
   def _transfer(self, dest, src, sz, src_dev, dest_dev):
-    if dest_dev.properties.transfer_supported and src_dev.conn == dest_dev.conn:
-      src_dev.q(Transfer(src, dest_dev.session, dest))
-    else:
-      src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
-      dest_dev.allocator._copyin(dest, tmp)
+    if dest_dev.conn != src_dev.conn:
+      dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
+      src_dev.q(Wait(start_event))
+    src_dev.q(Transfer(src, dest_dev.session, dest))
+    if dest_dev.conn != src_dev.conn:
+      src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
+      dest_dev.q(Wait(end_event))
+    if DEBUG >= 2: dest_dev.conn.batch_submit()
   def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
     self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
     return buffer_num
@@ -282,6 +309,8 @@ class RemoteProgram:
 
 @functools.cache
 class RemoteConnection:
+  all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering
+
   def __init__(self, host:str):
     if DEBUG >= 1: print(f"remote with host {host}")
     while 1:
@@ -293,15 +322,24 @@ class RemoteConnection:
         print(e)
         time.sleep(0.1)
     self.req: BatchRequest = BatchRequest()
+    RemoteConnection.all[self] = None
+
+  def q(self, x:RemoteRequest, wait:bool=False):
+    self.req.q(x)
+    if wait: return self.batch_submit()
 
   def batch_submit(self):
-    data = self.req.serialize()
-    with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=3):
-      self.conn.request("POST", "/batch", data)
-      response = self.conn.getresponse()
-      assert response.status == 200, f"POST /batch failed: {response}"
-      ret = response.read()
-    self.req = BatchRequest()
+    conns = RemoteConnection.all.keys()
+    datas = {conn: conn.req.serialize() for conn in conns}
+    reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
+    with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
+      for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
+      for conn in datas.keys():
+        response = conn.conn.getresponse()
+        assert response.status == 200, f"POST /batch failed: {response}"
+        resp = response.read()
+        if conn == self: ret = resp
+        conn.req = BatchRequest()
     return ret
 
 def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
@@ -316,12 +354,13 @@ class RemoteDevice(Compiled):
     host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
 
     # connection is shared between sessions on the same host
-    self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server())
+    self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
+    self.conn: RemoteConnection = RemoteConnection(self.session.host)
 
     # state for the session
-    self.session = SessionKey(idx, binascii.hexlify(os.urandom(0x10)).decode())
     self.buffer_num: Iterator[int] = itertools.count(0)
     self.graph_num: Iterator[int] = itertools.count(0)
+    self.event_num: Iterator[int] = itertools.count(0)
 
     self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
     if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
@@ -339,9 +378,7 @@ class RemoteDevice(Compiled):
   def finalize(self):
     with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
 
-  def q(self, x:RemoteRequest, wait:bool=False):
-    self.conn.req.q(replace(x, session=self.session))
-    if wait: return self.conn.batch_submit()
+  def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)
 
   @functools.cache
   @staticmethod
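
Not part of the patch: a minimal sketch of the Event/Wait handshake that the new `_transfer` path wraps around a cross-connection copy, modeled with plain asyncio instead of tinygrad's request classes (the coroutine names and event numbers are illustrative). The destination side signals an event once its previously queued work is drained, the source waits on it before issuing `Transfer`, then the direction flips so the destination only touches the buffer after the data has landed, mirroring the per-session `events` defaultdict on the server.

```python
import asyncio
from collections import defaultdict

# one asyncio.Event per event number, analogous to RemoteSession.events
events: defaultdict[int, asyncio.Event] = defaultdict(asyncio.Event)

async def dest_session(start_event: int, end_event: int):
  events[start_event].set()         # Event: destination is idle, source may copy
  await events[end_event].wait()    # Wait: block until the source signals completion
  print("dest: transfer finished, buffer is safe to use")

async def src_session(start_event: int, end_event: int):
  await events[start_event].wait()  # Wait: don't start before the destination is ready
  print("src: sending buffer")      # stands in for the Transfer request
  events[end_event].set()           # Event: tell the destination the data has landed

async def main():
  await asyncio.gather(dest_session(0, 1), src_session(0, 1))

asyncio.run(main())
```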