system: remote (#15318)

* system: remote

* listen

* print

* fix

* minor
This commit is contained in:
nimlgen
2026-03-17 19:25:37 +08:00
committed by GitHub
parent 69eefdca20
commit 0a641ce17d
3 changed files with 124 additions and 5 deletions

94
extra/remote/serve.py Normal file
View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python3
import socket, struct, sys
from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System
from tinygrad.helpers import DEBUG
def resp(resp0=0, resp1=0, status=0): return struct.pack('<BQQ', status, resp0, resp1)
def resp_err(msg): return struct.pack('<BQQ', 1, len(err:=msg.encode()), 0) + err
discovered_devices: list[str] = []
opened_devices: dict[int, PCIDevice] = {}
mapped_bars: dict[tuple[int, int], object] = {}
sysmem_allocs: list[tuple] = []
def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
if cmd == RemoteCmd.PROBE:
payload = conn.recv(arg1, socket.MSG_WAITALL) if arg1 > 0 else b""
filter_devices: dict[int, list[int]] = {}
for i in range(0, len(payload), 8):
mask, dev = struct.unpack('<II', payload[i:i+8])
filter_devices.setdefault(mask, []).append(dev)
base_class = None if arg0 == 0 else int(arg0)
devs = System.list_devices(arg2, tuple([(x, tuple(y)) for x,y in filter_devices.items()]), base_class)
for p in devs:
if p not in discovered_devices: discovered_devices.append(p)
data = "\n".join(f"{p[1]}:{discovered_devices.index(p)}" for p in devs).encode()
return conn.sendall(resp(len(data), len(devs)) + data)
# lazy device open
if dev_id not in opened_devices:
if dev_id >= len(discovered_devices): raise RuntimeError(f"device {dev_id} not probed")
cl, pcibus = discovered_devices[dev_id]
opened_devices[dev_id] = cl("SV", pcibus)
pci_dev = opened_devices[dev_id]
if cmd == RemoteCmd.MAP_BAR:
if (dev_id, bar) not in mapped_bars: mapped_bars[(dev_id, bar)] = pci_dev.map_bar(bar)
conn.sendall(resp(*pci_dev.bar_info(bar)))
elif cmd == RemoteCmd.CFG_READ:
conn.sendall(resp(pci_dev.read_config(arg0, arg1)))
elif cmd == RemoteCmd.CFG_WRITE:
pci_dev.write_config(arg0, arg2, arg1)
conn.sendall(resp())
elif cmd == RemoteCmd.RESIZE_BAR:
pci_dev.resize_bar(bar)
conn.sendall(resp())
elif cmd == RemoteCmd.RESET:
pci_dev.reset()
conn.sendall(resp())
elif cmd == RemoteCmd.MMIO_READ:
conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]])
elif cmd == RemoteCmd.MMIO_WRITE:
mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
elif cmd == RemoteCmd.MAP_SYSMEM:
memview, paddrs = pci_dev.alloc_sysmem(arg0)
hdl = len(sysmem_allocs)
sysmem_allocs.append((memview, paddrs))
paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs)
conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes)
elif cmd == RemoteCmd.SYSMEM_READ:
conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]])
elif cmd == RemoteCmd.SYSMEM_WRITE:
sysmem_allocs[bar][0][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
else: raise RuntimeError(f"unknown command {cmd}")
def serve(conn:socket.socket):
REQ = '<BIIQQQ'
while True:
hdr = conn.recv(struct.calcsize(REQ), socket.MSG_WAITALL)
if len(hdr) < struct.calcsize(REQ): raise ConnectionError("client disconnected")
cmd, dev_id, bar, arg0, arg1, arg2 = struct.unpack(REQ, hdr)
if DEBUG >= 4: print(f"cmd={RemoteCmd(cmd).name} dev={dev_id} bar={bar} arg0={arg0:#x} arg1={arg1:#x} arg2={arg2:#x}")
try: handle(conn, cmd, dev_id, bar, arg0, arg1, arg2)
except ConnectionError: raise
except Exception as e:
if cmd in {RemoteCmd.MMIO_WRITE, RemoteCmd.SYSMEM_WRITE}: raise ConnectionError(f"write failed: {e}")
print(f"ERROR: {e}")
conn.sendall(resp_err(str(e)))
if __name__ == "__main__":
port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(("0.0.0.0", port))
server.listen(1)
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: s.connect(("8.8.8.8", 80)); ip = s.getsockname()[0]
finally: s.close()
print(f"listening on {ip}:{port}")
while True:
conn, addr = server.accept()
conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
for bt in [socket.SO_SNDBUF, socket.SO_RCVBUF]: conn.setsockopt(socket.SOL_SOCKET, bt, 64 << 20)
try: serve(conn)
except ConnectionError: print("disconnected")

View File

@@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType
class MMIOInterface:
def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
def __len__(self): return self.nbytes // struct.calcsize(self.fmt)
def __getitem__(self, k): return (bytes(self.mv[k]) if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
def __getitem__(self, k): return (self.mv[k] if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
def __setitem__(self, k, v): self.mv[k] = v
def view(self, offset:int=0, size:int|None=None, fmt=None) -> MMIOInterface:
return MMIOInterface(self.addr+offset, (self.nbytes - offset) if size is None else size, fmt=fmt or self.fmt)

View File

@@ -75,6 +75,7 @@ class _System:
@functools.cache
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None) -> list[tuple[type, str]]:
if getenv("REMOTE", ""): return [(RemotePCIDevice, x) for x in RemotePCIDevice.remote_list(vendor, devices, base_class)]
return [(APLRemotePCIDevice if OSX else PCIDevice, x) for x in System.pci_scan_bus(vendor, devices, base_class)]
def pci_probe_device(self, devpref:str, dev_id:int, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
@@ -264,13 +265,15 @@ class PCIIfaceBase:
return [(p + self.pci_dev.bar_info(self.vram_bar)[0], sz) for p, sz in paddrs], AddrSpace.SYS
def map(self, b:HCQBuffer):
if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}")
if b.owner is not None and b.owner._is_cpu():
if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}")
System.lock_memory(int(b.va_addr), b.size)
paddrs, aspace = [(x, 0x1000) for x in System.system_paddrs(int(b.va_addr), round_up(b.size, 0x1000))], AddrSpace.SYS
snooped, uncached = True, True
elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase):
if ifa.is_bar_small(): raise RuntimeError(f"P2P mapping not supported for small bar devices: {b.owner} -> {self.dev}")
snooped, uncached = True, b.meta.mapping.uncached
if b.meta.mapping.aspace is AddrSpace.SYS: paddrs, aspace = b.meta.mapping.paddrs, AddrSpace.SYS
else: paddrs, aspace = ifa.p2p_paddrs(b.meta.mapping.paddrs)
@@ -304,6 +307,24 @@ class RemoteMMIOInterface(MMIOInterface):
return RemoteMMIOInterface(self.dev, self.residx, size or (self.nbytes - offset), fmt or self.fmt, self.off + offset, self.rd_cmd, self.wr_cmd)
class RemotePCIDevice(PCIDevice):
@staticmethod
@functools.cache
def remote_sock() -> socket.socket:
host_port = getenv("REMOTE", "127.0.0.1:6667").split(":")
host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.connect((host, port))
return sock
@staticmethod
@functools.cache
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[str]:
host, port = (sock:=RemotePCIDevice.remote_sock()).getpeername()
payload = array.array('I', itertools.chain.from_iterable((m, d) for m, ds in devices for d in ds)).tobytes()
data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
return [f"remote:{host}:{port}:{d}" for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')] if data_len else []
@staticmethod
def _recvall(sock:socket.socket, n:int) -> bytes:
data = b''
@@ -323,7 +344,7 @@ class RemotePCIDevice(PCIDevice):
return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,)
def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None):
self.sock, self.pcibus, self.dev_id = unwrap(sock), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
self.sock, self.pcibus, self.dev_id = sock or self.remote_sock(), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
for buft in [socket.SO_SNDBUF, socket.SO_RCVBUF]: self.sock.setsockopt(socket.SOL_SOCKET, buft, 64 << 20)
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
@@ -333,7 +354,11 @@ class RemotePCIDevice(PCIDevice):
def _bulk_write(self, cmd:int, idx:int, offset:int, data:bytes):
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]: raise NotImplementedError()
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size)
paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len)))
return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs
def reset(self): self._rpc(self.sock, self.dev_id, RemoteCmd.RESET)
def read_config(self, offset:int, size:int): return self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_READ, offset, size)[0]
def write_config(self, offset:int, value:int, size:int): self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_WRITE, offset, size, value)