mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remote: fix mmio (#15347)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
import socket, struct, sys
|
||||
from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.helpers import DEBUG, OSX
|
||||
|
||||
def resp(resp0=0, resp1=0, status=0): return struct.pack('<BQQ', status, resp0, resp1)
|
||||
def resp_err(msg): return struct.pack('<BQQ', 1, len(err:=msg.encode()), 0) + err
|
||||
@@ -47,15 +47,19 @@ def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
|
||||
pci_dev.reset()
|
||||
conn.sendall(resp())
|
||||
elif cmd == RemoteCmd.MMIO_READ:
|
||||
conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]])
|
||||
bar_view = mapped_bars[(dev_id, bar)]
|
||||
if arg0 % 4 == 0 and arg1 == 4: conn.sendmsg([resp(arg1), struct.pack(f'<{arg1 // 4}I', bar_view.view(arg0, arg1, fmt='I')[0])])
|
||||
else: conn.sendmsg([resp(arg1), bar_view[arg0:arg0+arg1]])
|
||||
elif cmd == RemoteCmd.MMIO_WRITE:
|
||||
mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
|
||||
data = conn.recv(arg1, socket.MSG_WAITALL)
|
||||
bar_view = mapped_bars[(dev_id, bar)]
|
||||
if arg0 % 4 == 0 and arg1 == 4: bar_view.view(arg0, arg1, fmt='I')[0] = struct.unpack(f'<{arg1 // 4}I', data)[0]
|
||||
else: bar_view[arg0:arg0+arg1] = data
|
||||
elif cmd == RemoteCmd.MAP_SYSMEM:
|
||||
memview, paddrs = pci_dev.alloc_sysmem(arg0)
|
||||
hdl = len(sysmem_allocs)
|
||||
memview, paddrs = pci_dev.alloc_sysmem(arg0, contiguous=bool(arg1))
|
||||
sysmem_allocs.append((memview, paddrs))
|
||||
paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs)
|
||||
conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes)
|
||||
conn.sendall(resp(len(paddrs_bytes), len(sysmem_allocs) - 1) + paddrs_bytes)
|
||||
elif cmd == RemoteCmd.SYSMEM_READ:
|
||||
conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]])
|
||||
elif cmd == RemoteCmd.SYSMEM_WRITE:
|
||||
@@ -77,6 +81,8 @@ def serve(conn:socket.socket):
|
||||
conn.sendall(resp_err(str(e)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not OSX: System.reserve_hugepages(128) # for sysmem allocations
|
||||
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
@@ -126,7 +126,7 @@ static int map_bar(uint32_t bar, response_t *resp) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int map_sysmem_fd(uint64_t size, response_t *resp, int *out_fd) {
|
||||
static int map_sysmem_fd(uint64_t size, int contiguous, response_t *resp, int *out_fd) {
|
||||
if (g_sysmem_count >= MAX_SYSMEM) return -1;
|
||||
int idx = g_sysmem_count;
|
||||
int fd = -1;
|
||||
@@ -208,7 +208,7 @@ static void handle_client(int fd) {
|
||||
|
||||
case CMD_MAP_SYSMEM_FD: {
|
||||
int shm_fd = -1;
|
||||
resp.status = map_sysmem_fd(req.arg0, &resp, &shm_fd) ? 1 : 0;
|
||||
resp.status = map_sysmem_fd(req.arg0, (int)req.arg1, &resp, &shm_fd) ? 1 : 0;
|
||||
send_response(fd, &resp, shm_fd);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -355,7 +355,7 @@ class RemotePCIDevice(PCIDevice):
|
||||
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
|
||||
|
||||
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
|
||||
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size)
|
||||
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size, int(contiguous))
|
||||
paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len)))
|
||||
return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs
|
||||
|
||||
@@ -396,7 +396,7 @@ class APLRemotePCIDevice(RemotePCIDevice):
|
||||
super().__init__(devpref, "usb4", sock=sock)
|
||||
|
||||
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
|
||||
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, has_fd=True)
|
||||
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, int(contiguous), has_fd=True)
|
||||
memview = MMIOInterface(FileIOInterface(fd=fd).mmap(0, mapped_size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, 0), mapped_size, fmt='B')
|
||||
|
||||
# paddrs are returned as (paddr, size) pairs until a (paddr=0, size=0) terminator in the beginning of the mapping.
|
||||
|
||||
Reference in New Issue
Block a user