Files
tinygrad/test/mockgpu/mockgpu.py
nimlgen 6de15dc480 mockam usb (#14916)
* mockam usb

* f

* win

* x

* x
2026-02-21 23:05:54 +03:00

115 lines
4.2 KiB
Python

import ctypes, ctypes.util, time, os, builtins, fcntl
from tinygrad.helpers import getenv
from tinygrad.runtime.support.hcq import FileIOInterface
from test.mockgpu.nv.nvdriver import NVDriver
from test.mockgpu.amd.amddriver import AMDDriver
from test.mockgpu.am.amdriver import AMDriver, AMUSBDriver
# Timestamp taken once, when this module is imported.
start = time.perf_counter()

# *** ioctl lib ***
# ctypes assumes an int return type unless told otherwise, so mmap's pointer result and its
# argument types are declared explicitly.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.restype = ctypes.c_void_p
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]

# AMD_IFACE selects the AMD mock: "PCI" -> AMDriver, "USB" -> AMUSBDriver, anything else -> AMDDriver.
# Only the selected class is instantiated, and NVDriver is always constructed first.
_amd_iface = getenv("AMD_IFACE", "")
drivers = [NVDriver(), {"PCI": AMDriver, "USB": AMUSBDriver}.get(_amd_iface, AMDDriver)()]

# fd number -> object returned by a driver's open(); an fd present here is served by that object.
tracked_fds = {}

# Reference to the real memoryview type, kept because builtins.memoryview is replaced further down.
original_memoryview = builtins.memoryview
class TrackedMemoryView:
  """Wrapper around a real memoryview that reports indexed reads and writes to callbacks.

  ``rcb(mv, index)`` is called before each indexed read and ``wcb(mv, index)`` after each
  indexed write, where ``mv`` is the wrapped view at the time of the access.
  """

  def __init__(self, data, rcb, wcb):
    """Wrap ``data`` in a real memoryview and remember the read/write callbacks."""
    self.mv = original_memoryview(data)
    self.rcb = rcb
    self.wcb = wcb

  def __getitem__(self, index):
    """Notify the read callback, then return the requested element or slice."""
    self.rcb(self.mv, index)
    result = self.mv[index]
    return result

  def __setitem__(self, index, value):
    """Store ``value`` first, then notify the write callback."""
    self.mv[index] = value
    self.wcb(self.mv, index)

  def cast(self, new_type, **kwargs):
    """Re-cast the wrapped view in place (through a byte view) and return this same wrapper."""
    as_bytes = self.mv.cast('B')
    self.mv = as_bytes.cast(new_type, **kwargs)
    return self

  @property
  def nbytes(self):
    """Size in bytes of the wrapped view."""
    return self.mv.nbytes

  def __len__(self):
    """Number of elements in the wrapped view."""
    return len(self.mv)

  def __repr__(self):
    """Same representation as the wrapped view."""
    return repr(self.mv)
def _memoryview(cls, mem):
  """Constructor used for the patched ``memoryview``.

  When ``mem`` is an int or a ctypes array whose address lies inside a range listed in some
  driver's ``tracked_addresses``, a TrackedMemoryView wired to that range's callbacks is
  returned. Every other input yields a real memoryview.
  """
  if not isinstance(mem, (int, ctypes.Array)):
    return original_memoryview(mem)

  # ints are taken as the address itself; ctypes arrays contribute the address of their buffer
  addr = mem if not isinstance(mem, ctypes.Array) else ctypes.addressof(mem)
  for driver in drivers:
    for lo, hi, on_read, on_write in driver.tracked_addresses:
      # both bounds are compared inclusively
      if lo <= addr <= hi:
        return TrackedMemoryView(mem, on_read, on_write)
  return original_memoryview(mem)
class _MockMemoryviewMeta(type):
  """Metaclass that lets ``isinstance(x, memoryview)`` accept real and tracked views alike."""

  def __instancecheck__(cls, instance):
    if isinstance(instance, original_memoryview):
      return True
    return isinstance(instance, TrackedMemoryView)

# Replace the builtin with a class named "memoryview" whose __new__ is _memoryview.
builtins.memoryview = _MockMemoryviewMeta("memoryview", (), {'__new__': _memoryview}) # type: ignore
def _open(path, flags):
  """Open ``path`` and return a file descriptor number, or None when nothing can be opened.

  A path equal to the ``path`` of one of a driver's ``tracked_files`` is opened by that driver;
  the returned object is recorded in ``tracked_fds`` under its ``fd``. Any other path is opened
  with ``os.open`` if it exists on disk.
  """
  for driver in drivers:
    for tracked in driver.tracked_files:
      if path != tracked.path:
        continue
      handle = driver.open(path, flags, 0o777, tracked)
      tracked_fds[handle.fd] = handle
      return handle.fd

  if not os.path.exists(path):
    return None
  return os.open(path, flags, 0o777)
class MockFileIOInterface(FileIOInterface):
  """FileIOInterface replacement that routes file operations to the mock drivers.

  A descriptor present in ``tracked_fds`` is served by the object stored there; any other
  descriptor is treated as a real OS file descriptor.
  """

  def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
    """Wrap ``fd`` when one is given, otherwise open ``path`` through ``_open``.

    ``self.fd`` ends up as None when no descriptor was passed and ``_open`` could not open ``path``.
    """
    self.path = path
    # compare against None: `fd or ...` would throw away a legitimate descriptor 0
    self.fd = fd if fd is not None else _open(path, flags)

  @staticmethod
  def _release_fd(fd):
    """Close ``fd`` through its tracking object when tracked, otherwise through the OS."""
    if fd in tracked_fds:
      tracked_fds[fd].close(fd)
      tracked_fds.pop(fd)
    else:
      os.close(fd)

  def __del__(self):
    """Release the descriptor, if this instance ever obtained one."""
    fd = getattr(self, "fd", None)
    # fd is missing when __init__ raised early and None when nothing could be opened
    if fd is None: return
    self._release_fd(fd)

  def ioctl(self, request, arg):
    """Issue an ioctl; the tracked path receives the address of ``arg`` (a ctypes object)."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].ioctl(self.fd, request, ctypes.addressof(arg))
    return fcntl.ioctl(self.fd, request, arg)

  def mmap(self, start, sz, prot, flags, offset):
    """Map the descriptor via its tracking object when tracked, otherwise via libc mmap."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].mmap(start, sz, prot, flags, self.fd, offset)
    return libc.mmap(start, sz, prot, flags, self.fd, offset)

  def read(self, size=None, binary=False, offset=None):
    """Read up to ``size`` items from the descriptor.

    Tracked descriptors seek to ``offset`` first when it is given and then return
    ``read_contents(size)``. Real descriptors are read in text mode only.

    Raises:
      NotImplementedError: ``binary`` is true for a descriptor that is not tracked.
    """
    if self.fd in tracked_fds:
      if offset is not None: tracked_fds[self.fd].seek(offset)
      return tracked_fds[self.fd].read_contents(size)
    if binary: raise NotImplementedError()
    # NOTE(review): `offset` is not applied to real descriptors here - confirm that is intended
    with open(self.fd, "r", closefd=False) as file:
      # rewind once the position is at or past the end so a repeated read returns the content again
      if file.tell() >= os.fstat(self.fd).st_size: file.seek(0)
      return file.read(size)

  def listdir(self):
    """List entries: ``list_contents()`` for a tracked descriptor, ``os.listdir(self.path)`` otherwise."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].list_contents()
    return os.listdir(self.path)

  def write(self, content, binary=False, offset=None):
    """Write ``content`` to a tracked descriptor, seeking to ``offset`` first when it is given.

    Raises:
      NotImplementedError: the descriptor is not tracked.
    """
    if self.fd in tracked_fds:
      if offset is not None: tracked_fds[self.fd].seek(offset)
      return tracked_fds[self.fd].write_contents(content)
    raise NotImplementedError()

  def seek(self, offset):
    """Move the file position of the descriptor."""
    if self.fd in tracked_fds:
      tracked_fds[self.fd].seek(offset)
    else:
      # NOTE(review): this is a relative seek (SEEK_CUR); if an absolute position is intended
      # this should be SEEK_SET - confirm against the base FileIOInterface before changing.
      os.lseek(self.fd, offset, os.SEEK_CUR)

  @staticmethod
  def anon_mmap(start, sz, prot, flags, offset):
    """Create a mapping with no backing descriptor (fd -1) through ``FileIOInterface._mmap``."""
    return FileIOInterface._mmap(start, sz, prot, flags & ~0x4a000, -1, offset) # strip MAP_LOCKED|MAP_POPULATE|MAP_HUGETLB

  @staticmethod
  def exists(path):
    """Return True when ``_open`` can open ``path`` read-only.

    The probe descriptor is released before returning so that repeated calls do not leak descriptors.
    """
    fd = _open(path, os.O_RDONLY)
    if fd is None: return False
    MockFileIOInterface._release_fd(fd)
    return True

  @staticmethod
  def readlink(path):
    """Not supported by the mock; always raises NotImplementedError."""
    raise NotImplementedError()

  @staticmethod
  def eventfd(initval, flags=None):
    """Not supported by the mock; always raises NotImplementedError."""
    # the exception used to be constructed without `raise`, so the call silently returned None
    raise NotImplementedError()