From 0a641ce17ddd8203333d591269f06fdad7494736 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:25:37 +0800 Subject: [PATCH] system: remote (#15318) * system: remote * listen * print * fix * minor --- extra/remote/serve.py | 94 ++++++++++++++++++++++++++++++ tinygrad/runtime/support/hcq.py | 2 +- tinygrad/runtime/support/system.py | 33 +++++++++-- 3 files changed, 124 insertions(+), 5 deletions(-) create mode 100644 extra/remote/serve.py diff --git a/extra/remote/serve.py b/extra/remote/serve.py new file mode 100644 index 0000000000..1cd2a62eec --- /dev/null +++ b/extra/remote/serve.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +import socket, struct, sys +from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System +from tinygrad.helpers import DEBUG + +def resp(resp0=0, resp1=0, status=0): return struct.pack(' 0 else b"" + filter_devices: dict[int, list[int]] = {} + for i in range(0, len(payload), 8): + mask, dev = struct.unpack('= len(discovered_devices): raise RuntimeError(f"device {dev_id} not probed") + cl, pcibus = discovered_devices[dev_id] + opened_devices[dev_id] = cl("SV", pcibus) + pci_dev = opened_devices[dev_id] + + if cmd == RemoteCmd.MAP_BAR: + if (dev_id, bar) not in mapped_bars: mapped_bars[(dev_id, bar)] = pci_dev.map_bar(bar) + conn.sendall(resp(*pci_dev.bar_info(bar))) + elif cmd == RemoteCmd.CFG_READ: + conn.sendall(resp(pci_dev.read_config(arg0, arg1))) + elif cmd == RemoteCmd.CFG_WRITE: + pci_dev.write_config(arg0, arg2, arg1) + conn.sendall(resp()) + elif cmd == RemoteCmd.RESIZE_BAR: + pci_dev.resize_bar(bar) + conn.sendall(resp()) + elif cmd == RemoteCmd.RESET: + pci_dev.reset() + conn.sendall(resp()) + elif cmd == RemoteCmd.MMIO_READ: + conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]]) + elif cmd == RemoteCmd.MMIO_WRITE: + mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL) + elif cmd == RemoteCmd.MAP_SYSMEM: + memview, paddrs = pci_dev.alloc_sysmem(arg0) + hdl = len(sysmem_allocs) + sysmem_allocs.append((memview, paddrs)) + paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs) + conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes) + elif cmd == RemoteCmd.SYSMEM_READ: + conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]]) + elif cmd == RemoteCmd.SYSMEM_WRITE: + sysmem_allocs[bar][0][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL) + else: raise RuntimeError(f"unknown command {cmd}") + +def serve(conn:socket.socket): + REQ = '= 4: print(f"cmd={RemoteCmd(cmd).name} dev={dev_id} bar={bar} arg0={arg0:#x} arg1={arg1:#x} arg2={arg2:#x}") + try: handle(conn, cmd, dev_id, bar, arg0, arg1, arg2) + except ConnectionError: raise + except Exception as e: + if cmd in {RemoteCmd.MMIO_WRITE, RemoteCmd.SYSMEM_WRITE}: raise ConnectionError(f"write failed: {e}") + print(f"ERROR: {e}") + conn.sendall(resp_err(str(e))) + +if __name__ == "__main__": + port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667 + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(("0.0.0.0", port)) + server.listen(1) + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: s.connect(("8.8.8.8", 80)); ip = s.getsockname()[0] + finally: s.close() + print(f"listening on {ip}:{port}") + while True: + conn, addr = server.accept() + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + for bt in [socket.SO_SNDBUF, socket.SO_RCVBUF]: conn.setsockopt(socket.SOL_SOCKET, bt, 64 << 20) + try: serve(conn) + except ConnectionError: print("disconnected") diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index 5cb6e7c4e9..e1e2e6500d 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType class MMIOInterface: def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt def __len__(self): return self.nbytes // struct.calcsize(self.fmt) - def __getitem__(self, k): return (bytes(self.mv[k]) if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k] + def __getitem__(self, k): return (self.mv[k] if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k] def __setitem__(self, k, v): self.mv[k] = v def view(self, offset:int=0, size:int|None=None, fmt=None) -> MMIOInterface: return MMIOInterface(self.addr+offset, (self.nbytes - offset) if size is None else size, fmt=fmt or self.fmt) diff --git a/tinygrad/runtime/support/system.py b/tinygrad/runtime/support/system.py index f53958ff78..3361cae133 100644 --- a/tinygrad/runtime/support/system.py +++ b/tinygrad/runtime/support/system.py @@ -75,6 +75,7 @@ class _System: @functools.cache def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None) -> list[tuple[type, str]]: + if getenv("REMOTE", ""): return [(RemotePCIDevice, x) for x in RemotePCIDevice.remote_list(vendor, devices, base_class)] return [(APLRemotePCIDevice if OSX else PCIDevice, x) for x in System.pci_scan_bus(vendor, devices, base_class)] def pci_probe_device(self, devpref:str, dev_id:int, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None): @@ -264,13 +265,15 @@ class PCIIfaceBase: return [(p + self.pci_dev.bar_info(self.vram_bar)[0], sz) for p, sz in paddrs], AddrSpace.SYS def map(self, b:HCQBuffer): - if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}") - if b.owner is not None and b.owner._is_cpu(): + if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}") + System.lock_memory(int(b.va_addr), b.size) paddrs, aspace = [(x, 0x1000) for x in System.system_paddrs(int(b.va_addr), round_up(b.size, 0x1000))], AddrSpace.SYS snooped, uncached = True, True elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase): + if ifa.is_bar_small(): raise RuntimeError(f"P2P mapping not supported for small bar devices: {b.owner} -> {self.dev}") + snooped, uncached = True, b.meta.mapping.uncached if b.meta.mapping.aspace is AddrSpace.SYS: paddrs, aspace = b.meta.mapping.paddrs, AddrSpace.SYS else: paddrs, aspace = ifa.p2p_paddrs(b.meta.mapping.paddrs) @@ -304,6 +307,24 @@ class RemoteMMIOInterface(MMIOInterface): return RemoteMMIOInterface(self.dev, self.residx, size or (self.nbytes - offset), fmt or self.fmt, self.off + offset, self.rd_cmd, self.wr_cmd) class RemotePCIDevice(PCIDevice): + @staticmethod + @functools.cache + def remote_sock() -> socket.socket: + host_port = getenv("REMOTE", "127.0.0.1:6667").split(":") + host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.connect((host, port)) + return sock + + @staticmethod + @functools.cache + def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[str]: + host, port = (sock:=RemotePCIDevice.remote_sock()).getpeername() + payload = array.array('I', itertools.chain.from_iterable((m, d) for m, ds in devices for d in ds)).tobytes() + data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload) + return [f"remote:{host}:{port}:{d}" for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')] if data_len else [] + @staticmethod def _recvall(sock:socket.socket, n:int) -> bytes: data = b'' @@ -323,7 +344,7 @@ class RemotePCIDevice(PCIDevice): return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,) def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None): - self.sock, self.pcibus, self.dev_id = unwrap(sock), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0 + self.sock, self.pcibus, self.dev_id = sock or self.remote_sock(), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0 for buft in [socket.SO_SNDBUF, socket.SO_RCVBUF]: self.sock.setsockopt(socket.SOL_SOCKET, buft, 64 << 20) self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock") @@ -333,7 +354,11 @@ class RemotePCIDevice(PCIDevice): def _bulk_write(self, cmd:int, idx:int, offset:int, data:bytes): self.sock.sendall(struct.pack(' tuple[MMIOInterface, list[int]]: raise NotImplementedError() + def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]: + paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size) + paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len))) + return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs + def reset(self): self._rpc(self.sock, self.dev_id, RemoteCmd.RESET) def read_config(self, offset:int, size:int): return self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_READ, offset, size)[0] def write_config(self, offset:int, value:int, size:int): self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_WRITE, offset, size, value)