add a status line to REMOTE with DEBUG=1 (#15471)

* python speedups of hot paths

* add a status line to REMOTE with DEBUG=1

* pc

* t
This commit is contained in:
George Hotz
2026-03-25 20:54:56 +08:00
committed by GitHub
parent c973b508b8
commit 25ff7146f2

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys, itertools, struct, socket, subprocess, time, enum
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir
import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys, itertools, struct, socket, subprocess, time, enum, atexit
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir, DEBUG
from tinygrad.runtime.autogen import libc, pci, vfio, iokit, corefoundation
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace, BumpAllocator
@@ -307,6 +307,11 @@ class RemoteMMIOInterface(MMIOInterface):
return RemoteMMIOInterface(self.dev, self.residx, size or (self.nbytes - offset), fmt or self.fmt, self.off + offset, self.rd_cmd, self.wr_cmd)
class RemotePCIDevice(PCIDevice):
_bulk_sent:int = 0
_bulk_recv:int = 0
_rpc_count:int = 0
_start_time:float = 0.0
@staticmethod
@functools.cache
def remote_sock() -> socket.socket:
@@ -317,6 +322,14 @@ class RemotePCIDevice(PCIDevice):
sock.settimeout(getenv("REMOTE_TIMEOUT", 3))
sock.connect((host, port))
sock.settimeout(None)
if DEBUG >= 1:
RemotePCIDevice._start_time = time.perf_counter()
def _print_stats():
dt = time.perf_counter() - RemotePCIDevice._start_time
sent_mb, recv_mb = RemotePCIDevice._bulk_sent / 1e6, RemotePCIDevice._bulk_recv / 1e6
print(f"remote: sent {sent_mb:,.2f} MB ({sent_mb/dt:,.2f} MB/s), recv {recv_mb:,.2f} MB ({recv_mb/dt:,.2f} MB/s), "
f"{RemotePCIDevice._rpc_count:,} roundtrips in {dt:.2f}s")
atexit.register(_print_stats)
return sock
@staticmethod
@@ -343,6 +356,7 @@ class RemotePCIDevice(PCIDevice):
else: msg, fd = RemotePCIDevice._recvall(sock, 17), None
if (resp:=struct.unpack('<BQQ', msg))[0] != 0:
raise RuntimeError(f"RPC failed: {RemotePCIDevice._recvall(sock, resp[1]).decode('utf-8') if resp[1] > 0 else 'unknown error'}")
RemotePCIDevice._rpc_count += 1
return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,)
def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None):
@@ -352,8 +366,10 @@ class RemotePCIDevice(PCIDevice):
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
def _bulk_read(self, cmd:int, idx:int, offset:int, size:int) -> bytes:
RemotePCIDevice._bulk_recv += size
return unwrap(self._rpc(self.sock, self.dev_id, cmd, offset, size, bar=idx, readout_size=size)[2])
def _bulk_write(self, cmd:int, idx:int, offset:int, data:bytes):
RemotePCIDevice._bulk_sent += len(data)
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]: