remote: fix mmio (#15347)

This commit is contained in:
nimlgen
2026-03-18 18:20:39 +08:00
committed by GitHub
parent f853371c83
commit ff004d2114
3 changed files with 16 additions and 10 deletions

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
import socket, struct, sys
from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, OSX
def resp(resp0=0, resp1=0, status=0): return struct.pack('<BQQ', status, resp0, resp1)
def resp_err(msg): return struct.pack('<BQQ', 1, len(err:=msg.encode()), 0) + err
@@ -47,15 +47,19 @@ def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
pci_dev.reset()
conn.sendall(resp())
elif cmd == RemoteCmd.MMIO_READ:
conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]])
bar_view = mapped_bars[(dev_id, bar)]
if arg0 % 4 == 0 and arg1 == 4: conn.sendmsg([resp(arg1), struct.pack(f'<{arg1 // 4}I', bar_view.view(arg0, arg1, fmt='I')[0])])
else: conn.sendmsg([resp(arg1), bar_view[arg0:arg0+arg1]])
elif cmd == RemoteCmd.MMIO_WRITE:
mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
data = conn.recv(arg1, socket.MSG_WAITALL)
bar_view = mapped_bars[(dev_id, bar)]
if arg0 % 4 == 0 and arg1 == 4: bar_view.view(arg0, arg1, fmt='I')[0] = struct.unpack(f'<{arg1 // 4}I', data)[0]
else: bar_view[arg0:arg0+arg1] = data
elif cmd == RemoteCmd.MAP_SYSMEM:
memview, paddrs = pci_dev.alloc_sysmem(arg0)
hdl = len(sysmem_allocs)
memview, paddrs = pci_dev.alloc_sysmem(arg0, contiguous=bool(arg1))
sysmem_allocs.append((memview, paddrs))
paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs)
conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes)
conn.sendall(resp(len(paddrs_bytes), len(sysmem_allocs) - 1) + paddrs_bytes)
elif cmd == RemoteCmd.SYSMEM_READ:
conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]])
elif cmd == RemoteCmd.SYSMEM_WRITE:
@@ -77,6 +81,8 @@ def serve(conn:socket.socket):
conn.sendall(resp_err(str(e)))
if __name__ == "__main__":
if not OSX: System.reserve_hugepages(128) # for sysmem allocations
port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

View File

@@ -126,7 +126,7 @@ static int map_bar(uint32_t bar, response_t *resp) {
return 0;
}
static int map_sysmem_fd(uint64_t size, response_t *resp, int *out_fd) {
static int map_sysmem_fd(uint64_t size, int contiguous, response_t *resp, int *out_fd) {
if (g_sysmem_count >= MAX_SYSMEM) return -1;
int idx = g_sysmem_count;
int fd = -1;
@@ -208,7 +208,7 @@ static void handle_client(int fd) {
case CMD_MAP_SYSMEM_FD: {
int shm_fd = -1;
resp.status = map_sysmem_fd(req.arg0, &resp, &shm_fd) ? 1 : 0;
resp.status = map_sysmem_fd(req.arg0, (int)req.arg1, &resp, &shm_fd) ? 1 : 0;
send_response(fd, &resp, shm_fd);
continue;
}

View File

@@ -355,7 +355,7 @@ class RemotePCIDevice(PCIDevice):
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size)
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size, int(contiguous))
paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len)))
return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs
@@ -396,7 +396,7 @@ class APLRemotePCIDevice(RemotePCIDevice):
super().__init__(devpref, "usb4", sock=sock)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, has_fd=True)
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, int(contiguous), has_fd=True)
memview = MMIOInterface(FileIOInterface(fd=fd).mmap(0, mapped_size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, 0), mapped_size, fmt='B')
# paddrs are returned as (paddr, size) pairs until a (paddr=0, size=0) terminator in the beginning of the mapping.