usb fast copyout (#15677)

* usb

* fix usb
This commit is contained in:
nimlgen
2026-04-10 21:04:49 +03:00
committed by GitHub
parent 0d5cdc9600
commit 58646f9569
2 changed files with 55 additions and 41 deletions

View File

@@ -4,10 +4,10 @@ import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, co
assert sys.platform != 'win32'
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, FileIOInterface
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filter_visible_devices
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filter_visible_devices, hcq_profile
from tinygrad.uop.ops import sint
from tinygrad.device import Compiled, BufferSpec
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
from tinygrad.helpers import VIZ, ceildiv, unwrap
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
@@ -506,6 +506,10 @@ class AMDCopyQueue(HWQueue):
*data64_le(signal.timestamp_addr))
return self
def write(self, b:HCQBuffer, val:sint, b64:bool=False):
  """Queue an SDMA_OP_WRITE packet that stores `val` at `b.va_addr`.

  Without `b64` the packet carries one data word, `lo32(val)`, and the word after the address is 0.
  With `b64` that word is 1 and `hi32(val)` is appended as a second data word.
  Returns `self` so calls can be chained.
  """
  if b64: count, payload = 1, (lo32(val), hi32(val))
  else: count, payload = 0, (lo32(val),)
  self.q(self.sdma.SDMA_OP_WRITE, *data64_le(b.va_addr), count, *payload)
  return self
def bind(self, dev:AMDDevice):
if not getenv("AMD_SDMA_BIND", 0) or not dev.is_am(): return
@@ -649,6 +653,21 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
def _map(self, buf:HCQBuffer):
  """Map `buf` through the device interface, substituting its `_base` buffer when one is set."""
  target = buf if buf._base is None else buf._base
  return self.dev.iface.map(target)
def _copyout(self, dest:memoryview, src:HCQBuffer):
  """Copy `src` into the host memoryview `dest`, staging through `self.b[0]` in chunks.

  Devices that are not USB, or whose USB controller is not flagged `is_custom`, defer to the parent
  implementation. Otherwise the device is synchronized and the copy runs inside an `hcq_profile`
  scope, at most `self.b[0].size` bytes per iteration.
  """
  if not (self.dev.is_usb() and self.dev.iface.pci_dev.usb.usb.is_custom): return super()._copyout(dest, src)

  self.dev.synchronize()
  trace_key = TracingKey(f"{self.dev.device} -> TINY", ret=dest.nbytes)
  with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=trace_key, enabled=PROFILE, dev_suff="SDMA:0"):
    chunk_size = self.b[0].size
    for start in range(0, dest.nbytes, chunk_size):
      nbytes = min(chunk_size, dest.nbytes - start)
      # arm the USB controller for this chunk before the copy queue is submitted
      self.dev.iface.pci_dev.usb.scsi_read_arm(nbytes)
      # NOTE(review): the write of 0 at cq_buf offset 12 presumably signals the controller firmware -- confirm
      (self.dev.hw_copy_queue_t()
         .wait(self.dev.timeline_signal, self.dev.timeline_value - 1)
         .copy(self.b[0], src.offset(start), nbytes)
         .write(self.dev.iface.cq_buf.offset(12), 0)
         .signal(self.dev.timeline_signal, self.dev.next_timeline())
         .submit(self.dev))
      # read the staged chunk back from the staging buffer's CPU view into the destination
      dest.cast('B')[start:start+nbytes] = self.b[0].cpu_view().view(size=nbytes, fmt='B')[:]
@dataclass
class AMDQueueDesc:
ring: MMIOInterface
@@ -899,6 +918,7 @@ class USBIface(PCIIface):
# special regions
self.copy_bufs = [self._dma_region(ctrl_addr=0xf000, sys_addr=0x200000, size=0x80000)]
self.sys_buf, self.sys_next_off = self._dma_region(ctrl_addr=0xa000, sys_addr=0x820000, size=0x1000), 0x800
self.cq_buf = self._dma_region(ctrl_addr=0xb800, sys_addr=0x822000, size=0x1000)
def _dma_region(self, ctrl_addr, sys_addr, size):
region = self.dev_impl.mm.map_range(vaddr:=self.dev_impl.mm.alloc_vaddr(size=size), size, [(sys_addr, size)], aspace=AddrSpace.SYS, uncached=True)
@@ -910,7 +930,7 @@ class USBIface(PCIIface):
return self.sys_buf.offset(self.sys_next_off - size, size)
# force devmem
return super().alloc(size, host=host, uncached=uncached, cpu_access=cpu_access, contiguous=contiguous, force_devmem=True, **kwargs)
return super().alloc(size, host=False, uncached=uncached, cpu_access=cpu_access, contiguous=contiguous, force_devmem=True, **kwargs)
def create_queue(self, queue_type, ring, gart, rptr, wptr, eop_buffer=None, cwsr_buffer=None, ctl_stack_size=0, ctx_save_restore_size=0,
xcc_id=0, idx=0):

View File

@@ -1,14 +1,19 @@
import ctypes, struct, dataclasses, array, itertools, time
from typing import Sequence
from tinygrad.runtime.autogen import libusb
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv, ceildiv
from tinygrad.runtime.support.hcq import MMIOInterface
def alloc_cbuffer(sz:int) -> tuple[ctypes.Array, memoryview]:
  """Allocate a ctypes unsigned-byte array of `sz` elements and return it with `to_mv` of its address and size."""
  storage = (ctypes.c_ubyte * sz)()
  # the array object is returned alongside the view so the caller can keep the backing memory alive
  return storage, to_mv(ctypes.addressof(storage), sz)
class USB3:
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
self.vendor, self.dev = vendor, dev
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
self.max_streams, self.use_bot = max_streams, use_bot
self._transferred = ctypes.c_int(0)
self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(4 << 20)
self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(4 << 20)
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
@@ -82,30 +87,17 @@ class USB3:
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
transferred = ctypes.c_int(0)
rc = libusb.libusb_bulk_transfer(
self.handle,
ep,
(ctypes.c_ubyte * len(payload))(*payload),
len(payload),
ctypes.byref(transferred),
timeout,
)
if len(payload) > len(self._bulk_out_mv): self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(len(payload))
self._bulk_out_mv[:len(payload)] = payload
rc = libusb.libusb_bulk_transfer(self.handle, ep, self._bulk_out_buf, len(payload), ctypes.byref(self._transferred), timeout)
assert rc == 0, f"bulk OUT 0x{ep:02X} failed: {rc}"
assert transferred.value == len(payload), f"bulk OUT short write on 0x{ep:02X}: {transferred.value}/{len(payload)} bytes"
assert self._transferred.value == len(payload), f"bulk OUT short write on 0x{ep:02X}: {self._transferred.value}/{len(payload)} bytes"
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> bytes:
buf, transferred = (ctypes.c_ubyte * length)(), ctypes.c_int(0)
rc = libusb.libusb_bulk_transfer(
self.handle,
ep,
buf,
length,
ctypes.byref(transferred),
timeout,
)
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> memoryview:
if length > len(self._bulk_in_mv): self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(length)
rc = libusb.libusb_bulk_transfer(self.handle, ep, self._bulk_in_buf, length, ctypes.byref(self._transferred), timeout)
assert rc == 0, f"bulk IN 0x{ep:02X} failed: {rc}"
return bytes(buf[:transferred.value])
return self._bulk_in_mv[:self._transferred.value]
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
@@ -126,7 +118,7 @@ class USB3:
# DAT
if dir_in:
results.append(self._bulk_in(self.ep_data_in, rlen))
results.append(bytes(self._bulk_in(self.ep_data_in, rlen)))
else:
if send_data is not None:
self._bulk_out(self.ep_data_out, send_data)
@@ -183,6 +175,10 @@ class CustomASM24Controller:
self._pci_cacheable: list[tuple[int, int]] = []
self._pci_cache: dict[int, int|None] = {}
# Pre-allocate buffers for _f0_out (12 bytes) and _f0_in (8 bytes)
self._f0_out_buf, self._f0_out_mv = alloc_cbuffer(12)
self._f0_in_buf, _ = alloc_cbuffer(8)
# Verify custom firmware is running and PCIe link is up (LTSSM=0x78).
ltssm = self.read(0xB450, 1)[0]
if ltssm != 0x78: raise RuntimeError(f"PCIe link not up (LTSSM=0x{ltssm:02X}), custom firmware not ready")
@@ -190,22 +186,14 @@ class CustomASM24Controller:
# === PCIe TLP via 0xF0 vendor command ===
def _f0_out(self, fmt_type:int, byte_en:int, address:int, value:int, mode:int=0):
"""Send 0xF0 OUT control transfer: configure TLP engine. 12-byte DATA_OUT = addr_lo[4 LE] + addr_hi[4 LE] + value[4 LE]."""
wval = fmt_type | (byte_en << 8)
widx = mode & 0x03
payload = struct.pack('<III', address & 0xFFFFFFFF, address >> 32, value)
buf = (ctypes.c_ubyte * 12)(*payload)
ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF0, wval, widx, buf, 12, 5000)
struct.pack_into('<III', self._f0_out_mv, 0, address & 0xFFFFFFFF, address >> 32, value)
ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF0, fmt_type | (byte_en << 8), mode & 0x03, self._f0_out_buf, 12, 5000)
assert ret == 12, f"F0 OUT failed: {ret}"
def _f0_in(self) -> tuple[int, int, int]:
"""Read 0xF0 IN: 8 bytes = data[4 LE] + cpl_hdr[2] + compl_status[1] + ret_status[1]. Returns (data, compl_status, ret_status)."""
buf = (ctypes.c_ubyte * 8)()
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xF0, 0, 0, buf, 8, 5000)
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xF0, 0, 0, self._f0_in_buf, 8, 5000)
assert ret == 8, f"F0 IN failed: {ret}"
data = struct.unpack('<I', bytes(buf[0:4]))[0]
cpl_status = (buf[4] >> 5) & 0x7 # completion status from CPL_HDR_HI bits [7:5]
return data, cpl_status, buf[7]
return struct.unpack_from('<I', self._f0_in_buf, 0)[0], (self._f0_in_buf[4] >> 5) & 0x7, self._f0_in_buf[7]
def _is_pci_cacheable(self, addr:int) -> bool: return any(x <= addr <= x + sz for x, sz in self._pci_cacheable)
@@ -265,10 +253,9 @@ class CustomASM24Controller:
result = b''
for off in range(0, length, 0xFF):
chunk = min(0xFF, length - off)
buf = (ctypes.c_ubyte * chunk)()
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, buf, chunk, 1000)
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, self._f0_out_buf, chunk, 1000)
assert ret == chunk, f"read(0x{base_addr + off:04X}, {chunk}) failed: {ret}"
result += bytes(buf[:ret])
result += bytes(self._f0_out_buf[:ret])
return result[:length]
def write(self, base_addr:int, data:bytes, **kwargs):
@@ -288,6 +275,12 @@ class CustomASM24Controller:
assert ret >= 0, f"F2 setup failed: {ret}"
self.usb._bulk_out(0x02, buf_padded)
def scsi_read_arm(self, size:int):
  """Send the 0xF2 vendor control transfer (request type 0x40, no data stage) for a read of `size` bytes.

  wValue is `ceildiv(size, 512)` masked to 15 bits with bit 15 set; wIndex holds `ceildiv(size, 0x4000)`
  masked to 8 bits in its high byte. Asserts that libusb returns a non-negative result.
  """
  units_512 = ceildiv(size, 512) & 0x7FFF
  units_16k = ceildiv(size, 0x4000) & 0xFF
  ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF2, units_512 | 0x8000, units_16k << 8, None, 0, 1000)
  assert ret >= 0, f"F2 read arm failed: {ret}"
def scsi_read(self, size:int) -> memoryview:
  """Bulk-read `round_up(size, 512)` bytes from endpoint 0x81 (10000 timeout) and return the first `size` bytes."""
  padded = round_up(size, 512)
  data = self.usb._bulk_in(0x81, padded, timeout=10000)
  return data[:size]
class ASM24Controller:
def __init__(self, usb:USB3|None=None):
@@ -429,6 +422,7 @@ class USBMMIOInterface(MMIOInterface):
def _acc(self, off, sz, data=None):
if data is None: # read op
if not self.pcimem:
if self.addr == 0xf000 and hasattr(self.usb, 'scsi_read'): return self.usb.scsi_read(sz)
return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz)
# Fast path: streaming PCIe read if controller supports it