mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
system: remote (#15318)
* system: remote * listen * print * fix * minor
This commit is contained in:
94
extra/remote/serve.py
Normal file
94
extra/remote/serve.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
import socket, struct, sys
|
||||
from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
||||
def resp(resp0=0, resp1=0, status=0): return struct.pack('<BQQ', status, resp0, resp1)
|
||||
def resp_err(msg): return struct.pack('<BQQ', 1, len(err:=msg.encode()), 0) + err
|
||||
|
||||
discovered_devices: list[str] = []
|
||||
opened_devices: dict[int, PCIDevice] = {}
|
||||
mapped_bars: dict[tuple[int, int], object] = {}
|
||||
sysmem_allocs: list[tuple] = []
|
||||
|
||||
def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
|
||||
if cmd == RemoteCmd.PROBE:
|
||||
payload = conn.recv(arg1, socket.MSG_WAITALL) if arg1 > 0 else b""
|
||||
filter_devices: dict[int, list[int]] = {}
|
||||
for i in range(0, len(payload), 8):
|
||||
mask, dev = struct.unpack('<II', payload[i:i+8])
|
||||
filter_devices.setdefault(mask, []).append(dev)
|
||||
base_class = None if arg0 == 0 else int(arg0)
|
||||
devs = System.list_devices(arg2, tuple([(x, tuple(y)) for x,y in filter_devices.items()]), base_class)
|
||||
for p in devs:
|
||||
if p not in discovered_devices: discovered_devices.append(p)
|
||||
data = "\n".join(f"{p[1]}:{discovered_devices.index(p)}" for p in devs).encode()
|
||||
return conn.sendall(resp(len(data), len(devs)) + data)
|
||||
|
||||
# lazy device open
|
||||
if dev_id not in opened_devices:
|
||||
if dev_id >= len(discovered_devices): raise RuntimeError(f"device {dev_id} not probed")
|
||||
cl, pcibus = discovered_devices[dev_id]
|
||||
opened_devices[dev_id] = cl("SV", pcibus)
|
||||
pci_dev = opened_devices[dev_id]
|
||||
|
||||
if cmd == RemoteCmd.MAP_BAR:
|
||||
if (dev_id, bar) not in mapped_bars: mapped_bars[(dev_id, bar)] = pci_dev.map_bar(bar)
|
||||
conn.sendall(resp(*pci_dev.bar_info(bar)))
|
||||
elif cmd == RemoteCmd.CFG_READ:
|
||||
conn.sendall(resp(pci_dev.read_config(arg0, arg1)))
|
||||
elif cmd == RemoteCmd.CFG_WRITE:
|
||||
pci_dev.write_config(arg0, arg2, arg1)
|
||||
conn.sendall(resp())
|
||||
elif cmd == RemoteCmd.RESIZE_BAR:
|
||||
pci_dev.resize_bar(bar)
|
||||
conn.sendall(resp())
|
||||
elif cmd == RemoteCmd.RESET:
|
||||
pci_dev.reset()
|
||||
conn.sendall(resp())
|
||||
elif cmd == RemoteCmd.MMIO_READ:
|
||||
conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]])
|
||||
elif cmd == RemoteCmd.MMIO_WRITE:
|
||||
mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
|
||||
elif cmd == RemoteCmd.MAP_SYSMEM:
|
||||
memview, paddrs = pci_dev.alloc_sysmem(arg0)
|
||||
hdl = len(sysmem_allocs)
|
||||
sysmem_allocs.append((memview, paddrs))
|
||||
paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs)
|
||||
conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes)
|
||||
elif cmd == RemoteCmd.SYSMEM_READ:
|
||||
conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]])
|
||||
elif cmd == RemoteCmd.SYSMEM_WRITE:
|
||||
sysmem_allocs[bar][0][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
|
||||
else: raise RuntimeError(f"unknown command {cmd}")
|
||||
|
||||
def serve(conn:socket.socket):
|
||||
REQ = '<BIIQQQ'
|
||||
while True:
|
||||
hdr = conn.recv(struct.calcsize(REQ), socket.MSG_WAITALL)
|
||||
if len(hdr) < struct.calcsize(REQ): raise ConnectionError("client disconnected")
|
||||
cmd, dev_id, bar, arg0, arg1, arg2 = struct.unpack(REQ, hdr)
|
||||
if DEBUG >= 4: print(f"cmd={RemoteCmd(cmd).name} dev={dev_id} bar={bar} arg0={arg0:#x} arg1={arg1:#x} arg2={arg2:#x}")
|
||||
try: handle(conn, cmd, dev_id, bar, arg0, arg1, arg2)
|
||||
except ConnectionError: raise
|
||||
except Exception as e:
|
||||
if cmd in {RemoteCmd.MMIO_WRITE, RemoteCmd.SYSMEM_WRITE}: raise ConnectionError(f"write failed: {e}")
|
||||
print(f"ERROR: {e}")
|
||||
conn.sendall(resp_err(str(e)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(("0.0.0.0", port))
|
||||
server.listen(1)
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try: s.connect(("8.8.8.8", 80)); ip = s.getsockname()[0]
|
||||
finally: s.close()
|
||||
print(f"listening on {ip}:{port}")
|
||||
while True:
|
||||
conn, addr = server.accept()
|
||||
conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
for bt in [socket.SO_SNDBUF, socket.SO_RCVBUF]: conn.setsockopt(socket.SOL_SOCKET, bt, 64 << 20)
|
||||
try: serve(conn)
|
||||
except ConnectionError: print("disconnected")
|
||||
@@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType
|
||||
class MMIOInterface:
|
||||
def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
|
||||
def __len__(self): return self.nbytes // struct.calcsize(self.fmt)
|
||||
def __getitem__(self, k): return (bytes(self.mv[k]) if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
|
||||
def __getitem__(self, k): return (self.mv[k] if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
|
||||
def __setitem__(self, k, v): self.mv[k] = v
|
||||
def view(self, offset:int=0, size:int|None=None, fmt=None) -> MMIOInterface:
|
||||
return MMIOInterface(self.addr+offset, (self.nbytes - offset) if size is None else size, fmt=fmt or self.fmt)
|
||||
|
||||
@@ -75,6 +75,7 @@ class _System:
|
||||
|
||||
@functools.cache
|
||||
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None) -> list[tuple[type, str]]:
|
||||
if getenv("REMOTE", ""): return [(RemotePCIDevice, x) for x in RemotePCIDevice.remote_list(vendor, devices, base_class)]
|
||||
return [(APLRemotePCIDevice if OSX else PCIDevice, x) for x in System.pci_scan_bus(vendor, devices, base_class)]
|
||||
|
||||
def pci_probe_device(self, devpref:str, dev_id:int, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
|
||||
@@ -264,13 +265,15 @@ class PCIIfaceBase:
|
||||
return [(p + self.pci_dev.bar_info(self.vram_bar)[0], sz) for p, sz in paddrs], AddrSpace.SYS
|
||||
|
||||
def map(self, b:HCQBuffer):
|
||||
if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}")
|
||||
|
||||
if b.owner is not None and b.owner._is_cpu():
|
||||
if not self.is_local(): raise RuntimeError(f"P2P mapping not supported for remote devices: {b.owner} -> {self.dev}")
|
||||
|
||||
System.lock_memory(int(b.va_addr), b.size)
|
||||
paddrs, aspace = [(x, 0x1000) for x in System.system_paddrs(int(b.va_addr), round_up(b.size, 0x1000))], AddrSpace.SYS
|
||||
snooped, uncached = True, True
|
||||
elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase):
|
||||
if ifa.is_bar_small(): raise RuntimeError(f"P2P mapping not supported for small bar devices: {b.owner} -> {self.dev}")
|
||||
|
||||
snooped, uncached = True, b.meta.mapping.uncached
|
||||
if b.meta.mapping.aspace is AddrSpace.SYS: paddrs, aspace = b.meta.mapping.paddrs, AddrSpace.SYS
|
||||
else: paddrs, aspace = ifa.p2p_paddrs(b.meta.mapping.paddrs)
|
||||
@@ -304,6 +307,24 @@ class RemoteMMIOInterface(MMIOInterface):
|
||||
return RemoteMMIOInterface(self.dev, self.residx, size or (self.nbytes - offset), fmt or self.fmt, self.off + offset, self.rd_cmd, self.wr_cmd)
|
||||
|
||||
class RemotePCIDevice(PCIDevice):
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def remote_sock() -> socket.socket:
|
||||
host_port = getenv("REMOTE", "127.0.0.1:6667").split(":")
|
||||
host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
sock.connect((host, port))
|
||||
return sock
|
||||
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[str]:
|
||||
host, port = (sock:=RemotePCIDevice.remote_sock()).getpeername()
|
||||
payload = array.array('I', itertools.chain.from_iterable((m, d) for m, ds in devices for d in ds)).tobytes()
|
||||
data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
|
||||
return [f"remote:{host}:{port}:{d}" for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')] if data_len else []
|
||||
|
||||
@staticmethod
|
||||
def _recvall(sock:socket.socket, n:int) -> bytes:
|
||||
data = b''
|
||||
@@ -323,7 +344,7 @@ class RemotePCIDevice(PCIDevice):
|
||||
return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,)
|
||||
|
||||
def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None):
|
||||
self.sock, self.pcibus, self.dev_id = unwrap(sock), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
|
||||
self.sock, self.pcibus, self.dev_id = sock or self.remote_sock(), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
|
||||
for buft in [socket.SO_SNDBUF, socket.SO_RCVBUF]: self.sock.setsockopt(socket.SOL_SOCKET, buft, 64 << 20)
|
||||
|
||||
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
|
||||
@@ -333,7 +354,11 @@ class RemotePCIDevice(PCIDevice):
|
||||
def _bulk_write(self, cmd:int, idx:int, offset:int, data:bytes):
|
||||
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
|
||||
|
||||
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]: raise NotImplementedError()
|
||||
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
|
||||
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size)
|
||||
paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len)))
|
||||
return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs
|
||||
|
||||
def reset(self): self._rpc(self.sock, self.dev_id, RemoteCmd.RESET)
|
||||
def read_config(self, offset:int, size:int): return self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_READ, offset, size)[0]
|
||||
def write_config(self, offset:int, value:int, size:int): self._rpc(self.sock, self.dev_id, RemoteCmd.CFG_WRITE, offset, size, value)
|
||||
|
||||
Reference in New Issue
Block a user