From 2b01ca59dd1f0ff770e6eddca3148ff361e2e420 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:45:41 +0800 Subject: [PATCH] USB driver for custom ASM firmware (#15597) * USB driver for custom ASM firmware * timeout * fix mypy * pcie mem read * flip in f/w * one tx * little endian * autodetect custom * mock bypass * lint * clean --- test/mockgpu/usb.py | 3 +- tinygrad/runtime/support/system.py | 8 +- tinygrad/runtime/support/usb.py | 131 ++++++++++++++++++++++++++++- 3 files changed, 135 insertions(+), 7 deletions(-) diff --git a/test/mockgpu/usb.py b/test/mockgpu/usb.py index fe1d46663a..cd59a9ce43 100644 --- a/test/mockgpu/usb.py +++ b/test/mockgpu/usb.py @@ -202,7 +202,8 @@ class MockASM24State: return None class MockUSB3: - def __init__(self, *args, **kwargs): pass + def __init__(self, *args, **kwargs): + self.product, self.is_custom = "", False def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]: assert _mock_usb_state is not None idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs) diff --git a/tinygrad/runtime/support/system.py b/tinygrad/runtime/support/system.py index 3f659e6808..e7858a5b73 100644 --- a/tinygrad/runtime/support/system.py +++ b/tinygrad/runtime/support/system.py @@ -4,7 +4,7 @@ from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch from tinygrad.runtime.autogen import libc, pci, vfio, iokit, corefoundation from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices from tinygrad.runtime.support.memory import VirtMapping, AddrSpace, BumpAllocator -from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface +from tinygrad.runtime.support.usb import USB3, CustomASM24Controller, ASM24Controller, USBMMIOInterface MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000 MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 
0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400 @@ -82,7 +82,7 @@ class _System: cl, pcibus = hcq_filter_visible_devices(self.list_devices(vendor, devices, base_class))[dev_id] return cl(devpref, pcibus) - def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]: + def pci_setup_usb_bars(self, usb:CustomASM24Controller|ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]: for bus in range(gpu_bus): # All 3 values must be written at the same time. buses = (0 << 0) | ((bus+1) << 8) | ((gpu_bus) << 16) @@ -213,7 +213,9 @@ class PCIDevice: class USBPCIDevice(PCIDevice): def __init__(self, devpref:str, pcibus:str): self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock") - self.usb = ASM24Controller() + usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04) + if DEBUG >= 1: print(f"am usb: product string: {usb.product!r}") + self.usb: CustomASM24Controller | ASM24Controller = CustomASM24Controller(usb) if usb.is_custom else ASM24Controller(usb) self.pcibus, self._bar_info = pcibus, System.pci_setup_usb_bars(self.usb, gpu_bus=4, mem_base=0x10000000, pref_mem_base=(32 << 30)) self.sram = BumpAllocator(size=0x80000, wrap=False) # asm24 controller sram diff --git a/tinygrad/runtime/support/usb.py b/tinygrad/runtime/support/usb.py index b9de1b7e0e..27307e6a1e 100644 --- a/tinygrad/runtime/support/usb.py +++ b/tinygrad/runtime/support/usb.py @@ -1,4 +1,4 @@ -import ctypes, struct, dataclasses, array, itertools +import ctypes, struct, dataclasses, array, itertools, time from typing import Sequence from tinygrad.runtime.autogen import libusb from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv @@ -17,6 +17,15 @@ class USB3: self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev) if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. 
sudo required?") + # Read product string descriptor + _buf = (ctypes.c_ubyte * 256)() + _desc = libusb.struct_libusb_device_descriptor() + libusb.libusb_get_device_descriptor(libusb.libusb_get_device(self.handle), ctypes.byref(_desc)) + _ret = libusb.libusb_get_string_descriptor_ascii(self.handle, _desc.iProduct, _buf, 256) + self.product = bytes(_buf[:max(_ret, 0)]).decode("ascii", errors="replace") if _ret > 0 else "" + self.is_custom = self.product.startswith("custom") + if self.is_custom: self.use_bot = use_bot = True + # Detach kernel driver if needed if libusb.libusb_kernel_driver_active(self.handle, 0): libusb.libusb_detach_kernel_driver(self.handle, 0) @@ -168,9 +177,121 @@ class ReadOp: addr:int; size:int # noqa: E702 @dataclasses.dataclass(frozen=True) class ScsiWriteOp: data:bytes; lba:int=0 # noqa: E702 +class CustomASM24Controller: + def __init__(self, usb:USB3|None=None): + self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=True) + self._pci_cacheable: list[tuple[int, int]] = [] + self._pci_cache: dict[int, int|None] = {} + + # Verify custom firmware is running and PCIe link is up (LTSSM=0x78). + ltssm = self.read(0xB450, 1)[0] + if ltssm != 0x78: raise RuntimeError(f"PCIe link not up (LTSSM=0x{ltssm:02X}), custom firmware not ready") + + # === PCIe TLP via 0xF0 vendor command === + + def _f0_out(self, fmt_type:int, byte_en:int, address:int, value:int, mode:int=0): + """Send 0xF0 OUT control transfer: configure TLP engine. 12-byte DATA_OUT = addr_lo[4 LE] + addr_hi[4 LE] + value[4 LE].""" + wval = fmt_type | (byte_en << 8) + widx = mode & 0x03 + payload = struct.pack('<III', address & 0xFFFFFFFF, address >> 32, value) + buf = (ctypes.c_ubyte * 12)(*payload) + ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF0, wval, widx, buf, 12, 5000) + assert ret == 12, f"F0 OUT failed: {ret}" + + def _f0_in(self) -> tuple[int, int, int]: + """Read 0xF0 IN: 8 bytes = data[4 LE] + cpl_hdr[2] + compl_status[1] + ret_status[1]. 
Returns (data, compl_status, ret_status).""" + buf = (ctypes.c_ubyte * 8)() + ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xF0, 0, 0, buf, 8, 5000) + assert ret == 8, f"F0 IN failed: {ret}" + data = struct.unpack('> 5) & 0x7 # completion status from CPL_HDR_HI bits [7:5] + return data, cpl_status, buf[7] + + def _is_pci_cacheable(self, addr:int) -> bool: return any(x <= addr <= x + sz for x, sz in self._pci_cacheable) + + def pcie_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4, cnt:int=10): + if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return + assert size > 0 and size <= 4, f"Invalid size {size}" + if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size) + + offset = address & 0x3 + byte_en = ((1 << size) - 1) << offset + self._pci_cache[address] = value if size == 4 and fmt_type == 0x60 else None + + self._f0_out(fmt_type, byte_en, address & ~0x3, (value << (8 * offset)) if value is not None else 0) + + # Fast path: memory writes and messages don't return completions (same logic as ASM24Controller). + if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000): return + + # Read TLPs and config writes: read completion via 0xF0 IN. Retry on error/timeout. 
+ data, cpl_status, ret_status = self._f0_in() + if ret_status != 0: + time.sleep(0.001) # TODO: this sleep is very picky + if cnt > 0: + return self.pcie_request(fmt_type, address, value, size, cnt=cnt-1) + raise RuntimeError(f"TLP error after retries: ret_status={ret_status}, address={address:#x}") + + if cpl_status: + status_map = {0b001: f"Unsupported Request: {address:#x}", 0b100: "Completer Abort", 0b010: "Config Retry"} + raise RuntimeError(f"TLP completion status: {status_map.get(cpl_status, f'Reserved (0b{cpl_status:03b})')}") + + if value is None: return (data >> (8 * offset)) & ((1 << (8 * size)) - 1) + + def pcie_cfg_req(self, byte_addr:int, bus:int=1, dev:int=0, fn:int=0, value:int|None=None, size:int=4): + assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0 + fmt_type = (0x44 if value is not None else 0x4) | int(bus > 0) + address = (bus << 24) | (dev << 19) | (fn << 16) | (byte_addr & 0xfff) + return self.pcie_request(fmt_type, address, value, size) + + def pcie_mem_req(self, address:int, value:int|None=None, size:int=4): + return self.pcie_request(0x60 if value is not None else 0x20, address, value, size) + + def pcie_mem_write(self, address:int, values:list[int], size:int): + """Streaming PCIe memory write via 0xF0 mode 1 + bulk OUT. Data is little-endian dwords on the wire.""" + if not values: return + self._f0_out(0x60, 0x0F, address, len(values), mode=1) + self.usb._bulk_out(0x02, struct.pack(f'<{len(values)}I', *values)) + + def pcie_mem_read(self, address:int, nbytes:int) -> bytes: + """Streaming PCIe memory read via 0xF0 mode 2 + bulk IN. 
Returns little-endian bytes.""" + assert nbytes % 4 == 0, f"pcie_mem_read requires 4-byte aligned size, got {nbytes}" + self._f0_out(0x20, 0x0F, address, nbytes // 4, mode=2) + return self.usb._bulk_in(0x81, nbytes, timeout=30000) + + # === XDATA read/write (0xE4/0xE5 vendor control transfers) === + + def read(self, base_addr:int, length:int, **kwargs) -> bytes: + """Read from chip XDATA via vendor control IN (bRequest=0xE4). wValue=addr, wLength=size.""" + result = b'' + for off in range(0, length, 0xFF): + chunk = min(0xFF, length - off) + buf = (ctypes.c_ubyte * chunk)() + ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, buf, chunk, 1000) + assert ret == chunk, f"read(0x{base_addr + off:04X}, {chunk}) failed: {ret}" + result += bytes(buf[:ret]) + return result[:length] + + def write(self, base_addr:int, data:bytes, **kwargs): + """Write to chip XDATA via vendor control OUT (bRequest=0xE5). wValue=addr, wIndex=val.""" + for off, val in enumerate(data): + ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xE5, base_addr + off, val, None, 0, 1000) + assert ret >= 0, f"write(0x{base_addr + off:04X}, 0x{val:02X}) failed: {ret}" + + def scsi_write(self, buf:bytes, lba:int=0): + """Write to SRAM via 0xF2 vendor command + bulk OUT.""" + buf_padded = buf + b'\x00' * (round_up(len(buf), 512) - len(buf)) + sectors = len(buf_padded) // 512 + num_slots = round_up(len(buf_padded), 0x4000) // 0x4000 # 16KB per slot + # 0xF2 OUT: wValue=sectors, wIndex=start_slot|(num_slots<<8) + windex = (num_slots & 0xFF) << 8 + ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF2, sectors, windex, None, 0, 1000) + assert ret >= 0, f"F2 setup failed: {ret}" + self.usb._bulk_out(0x02, buf_padded) + + class ASM24Controller: - def __init__(self): - self.usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=bool(getenv("USE_BOT", 0))) + def __init__(self, usb:USB3|None=None): + self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 
0x04, use_bot=bool(getenv("USE_BOT", 0))) self._cache: dict[int, int|None] = {} self._pci_cacheable: list[tuple[int, int]] = [] self._pci_cache: dict[int, int|None] = {} @@ -310,6 +431,10 @@ class USBMMIOInterface(MMIOInterface): if not self.pcimem: return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz) + # Fast path: streaming PCIe read if controller supports it + if hasattr(self.usb, 'pcie_mem_read') and sz >= 4 and sz % 4 == 0: + return self.usb.pcie_mem_read(self.addr + off, sz) + acc, acc_size = self._acc_size(sz) return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)]))