mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -4,10 +4,10 @@ import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, co
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, FileIOInterface
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filter_visible_devices
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filter_visible_devices, hcq_profile
|
||||
from tinygrad.uop.ops import sint
|
||||
from tinygrad.device import Compiled, BufferSpec
|
||||
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar
|
||||
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
|
||||
from tinygrad.helpers import VIZ, ceildiv, unwrap
|
||||
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
@@ -506,6 +506,10 @@ class AMDCopyQueue(HWQueue):
|
||||
*data64_le(signal.timestamp_addr))
|
||||
return self
|
||||
|
||||
def write(self, b:HCQBuffer, val:sint, b64:bool=False):
  """Queue an SDMA_OP_WRITE packet that stores `val` at the address of `b`.

  The packet is built from the two 32-bit halves of `b.va_addr`, a flag word (1 when `b64` is set, 0 otherwise)
  and the low 32 bits of `val`; the high 32 bits of `val` are appended only when `b64` is set.
  Returns `self` so calls can be chained.
  """
  addr_words = data64_le(b.va_addr)
  value_words = [lo32(val), hi32(val)] if b64 else [lo32(val)]
  self.q(self.sdma.SDMA_OP_WRITE, *addr_words, 1 if b64 else 0, *value_words)
  return self
|
||||
|
||||
def bind(self, dev:AMDDevice):
|
||||
if not getenv("AMD_SDMA_BIND", 0) or not dev.is_am(): return
|
||||
|
||||
@@ -649,6 +653,21 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
|
||||
|
||||
def _map(self, buf:HCQBuffer):
  """Map `buf` through the device interface, passing its `_base` buffer instead when one is set."""
  target = buf if buf._base is None else buf._base
  return self.dev.iface.map(target)
|
||||
|
||||
def _copyout(self, dest:memoryview, src:HCQBuffer):
  """Copy `src` into `dest`, using the chunked USB read path when the device is USB with a custom controller.

  Any other device falls back to the parent class implementation. On the USB path the device is synchronized
  first, then the data is moved in pieces of at most `self.b[0].size` bytes: each piece is armed on the USB
  controller, copied into `self.b[0]` by a copy-queue submission, and then read back from `self.b[0]` into `dest`.
  """
  if not self.dev.is_usb() or not self.dev.iface.pci_dev.usb.usb.is_custom: return super()._copyout(dest, src)
  self.dev.synchronize()

  with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=TracingKey(f"{self.dev.device} -> TINY", ret=dest.nbytes), enabled=PROFILE,
                   dev_suff="SDMA:0"):
    staging = self.b[0]
    cp_size = staging.size
    for off in range(0, dest.nbytes, cp_size):
      # the last piece may be shorter than the staging buffer
      lsize = min(cp_size, dest.nbytes - off)
      self.dev.iface.pci_dev.usb.scsi_read_arm(lsize)
      (self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1)
                                 .copy(staging, src.offset(off), lsize)
                                 .write(self.dev.iface.cq_buf.offset(12), 0)
                                 .signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev))
      dest.cast('B')[off:off+lsize] = staging.cpu_view().view(size=lsize, fmt='B')[:]
|
||||
|
||||
@dataclass
|
||||
class AMDQueueDesc:
|
||||
ring: MMIOInterface
|
||||
@@ -899,6 +918,7 @@ class USBIface(PCIIface):
|
||||
# special regions
|
||||
self.copy_bufs = [self._dma_region(ctrl_addr=0xf000, sys_addr=0x200000, size=0x80000)]
|
||||
self.sys_buf, self.sys_next_off = self._dma_region(ctrl_addr=0xa000, sys_addr=0x820000, size=0x1000), 0x800
|
||||
self.cq_buf = self._dma_region(ctrl_addr=0xb800, sys_addr=0x822000, size=0x1000)
|
||||
|
||||
def _dma_region(self, ctrl_addr, sys_addr, size):
|
||||
region = self.dev_impl.mm.map_range(vaddr:=self.dev_impl.mm.alloc_vaddr(size=size), size, [(sys_addr, size)], aspace=AddrSpace.SYS, uncached=True)
|
||||
@@ -910,7 +930,7 @@ class USBIface(PCIIface):
|
||||
return self.sys_buf.offset(self.sys_next_off - size, size)
|
||||
|
||||
# force devmem
|
||||
return super().alloc(size, host=host, uncached=uncached, cpu_access=cpu_access, contiguous=contiguous, force_devmem=True, **kwargs)
|
||||
return super().alloc(size, host=False, uncached=uncached, cpu_access=cpu_access, contiguous=contiguous, force_devmem=True, **kwargs)
|
||||
|
||||
def create_queue(self, queue_type, ring, gart, rptr, wptr, eop_buffer=None, cwsr_buffer=None, ctl_stack_size=0, ctx_save_restore_size=0,
|
||||
xcc_id=0, idx=0):
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import ctypes, struct, dataclasses, array, itertools, time
|
||||
from typing import Sequence
|
||||
from tinygrad.runtime.autogen import libusb
|
||||
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv
|
||||
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv, ceildiv
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
|
||||
def alloc_cbuffer(sz:int) -> tuple[ctypes.Array, memoryview]:
  """Allocate a `sz`-byte ctypes unsigned-byte array and return it with the view `to_mv` builds over its address.

  Both objects are returned so the caller can keep the array alive while using the view.
  """
  storage = (ctypes.c_ubyte * sz)()
  return storage, to_mv(ctypes.addressof(storage), sz)
|
||||
|
||||
class USB3:
|
||||
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
|
||||
self.vendor, self.dev = vendor, dev
|
||||
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
|
||||
self.max_streams, self.use_bot = max_streams, use_bot
|
||||
self._transferred = ctypes.c_int(0)
|
||||
self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(4 << 20)
|
||||
self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(4 << 20)
|
||||
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
|
||||
|
||||
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
|
||||
@@ -82,30 +87,17 @@ class USB3:
|
||||
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
|
||||
|
||||
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
  """Send `payload` to bulk endpoint `ep` with a single libusb bulk transfer.

  The payload is staged in the reusable `self._bulk_out_buf`, which is replaced by a larger buffer from
  `alloc_cbuffer` when the payload does not fit. `timeout` is forwarded to `libusb_bulk_transfer`.

  Raises AssertionError when libusb returns a non-zero code or fewer bytes than `len(payload)` were transferred.
  """
  size = len(payload)
  # grow the staging buffer only when needed; it is otherwise reused across calls
  if size > len(self._bulk_out_mv): self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(size)
  self._bulk_out_mv[:size] = payload
  rc = libusb.libusb_bulk_transfer(self.handle, ep, self._bulk_out_buf, size, ctypes.byref(self._transferred), timeout)
  assert rc == 0, f"bulk OUT 0x{ep:02X} failed: {rc}"
  assert self._transferred.value == size, f"bulk OUT short write on 0x{ep:02X}: {self._transferred.value}/{size} bytes"
|
||||
|
||||
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> memoryview:
  """Receive up to `length` bytes from bulk endpoint `ep` with a single libusb bulk transfer.

  Data lands in the reusable `self._bulk_in_buf`, which is replaced by a larger buffer from `alloc_cbuffer`
  when `length` exceeds its size. `timeout` is forwarded to `libusb_bulk_transfer`.

  Returns a memoryview over the first `self._transferred.value` bytes of that shared buffer. The view aliases
  the buffer, so a later `_bulk_in` call can overwrite its contents; wrap it in `bytes(...)` to keep the data.

  Raises AssertionError when libusb returns a non-zero code.
  """
  if length > len(self._bulk_in_mv): self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(length)
  rc = libusb.libusb_bulk_transfer(self.handle, ep, self._bulk_in_buf, length, ctypes.byref(self._transferred), timeout)
  assert rc == 0, f"bulk IN 0x{ep:02X} failed: {rc}"
  return self._bulk_in_mv[:self._transferred.value]
|
||||
|
||||
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
|
||||
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
|
||||
@@ -126,7 +118,7 @@ class USB3:
|
||||
|
||||
# DAT
|
||||
if dir_in:
|
||||
results.append(self._bulk_in(self.ep_data_in, rlen))
|
||||
results.append(bytes(self._bulk_in(self.ep_data_in, rlen)))
|
||||
else:
|
||||
if send_data is not None:
|
||||
self._bulk_out(self.ep_data_out, send_data)
|
||||
@@ -183,6 +175,10 @@ class CustomASM24Controller:
|
||||
self._pci_cacheable: list[tuple[int, int]] = []
|
||||
self._pci_cache: dict[int, int|None] = {}
|
||||
|
||||
# Pre-allocate buffers for _f0_out (12 bytes) and _f0_in (8 bytes)
|
||||
self._f0_out_buf, self._f0_out_mv = alloc_cbuffer(12)
|
||||
self._f0_in_buf, _ = alloc_cbuffer(8)
|
||||
|
||||
# Verify custom firmware is running and PCIe link is up (LTSSM=0x78).
|
||||
ltssm = self.read(0xB450, 1)[0]
|
||||
if ltssm != 0x78: raise RuntimeError(f"PCIe link not up (LTSSM=0x{ltssm:02X}), custom firmware not ready")
|
||||
@@ -190,22 +186,14 @@ class CustomASM24Controller:
|
||||
# === PCIe TLP via 0xF0 vendor command ===
|
||||
|
||||
def _f0_out(self, fmt_type:int, byte_en:int, address:int, value:int, mode:int=0):
  """Send 0xF0 OUT control transfer: configure TLP engine. 12-byte DATA_OUT = addr_lo[4 LE] + addr_hi[4 LE] + value[4 LE].

  wValue is `fmt_type | (byte_en << 8)` and wIndex is the low two bits of `mode`. The payload is packed in place
  into the preallocated `self._f0_out_buf` (through its view `self._f0_out_mv`), so no buffer is allocated per call.

  Raises AssertionError when the transfer does not report exactly 12 bytes.
  """
  struct.pack_into('<III', self._f0_out_mv, 0, address & 0xFFFFFFFF, address >> 32, value)
  wval, widx = fmt_type | (byte_en << 8), mode & 0x03
  ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF0, wval, widx, self._f0_out_buf, 12, 5000)
  assert ret == 12, f"F0 OUT failed: {ret}"
|
||||
|
||||
def _f0_in(self) -> tuple[int, int, int]:
  """Read 0xF0 IN: 8 bytes = data[4 LE] + cpl_hdr[2] + compl_status[1] + ret_status[1]. Returns (data, compl_status, ret_status).

  The response is received into the preallocated `self._f0_in_buf`, so no buffer is allocated per call.

  Raises AssertionError when the transfer does not report exactly 8 bytes.
  """
  ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xF0, 0, 0, self._f0_in_buf, 8, 5000)
  assert ret == 8, f"F0 IN failed: {ret}"
  data = struct.unpack_from('<I', self._f0_in_buf, 0)[0]
  cpl_status = (self._f0_in_buf[4] >> 5) & 0x7  # completion status from CPL_HDR_HI bits [7:5]
  return data, cpl_status, self._f0_in_buf[7]
|
||||
|
||||
def _is_pci_cacheable(self, addr:int) -> bool:
  """Return True when `addr` falls inside any `(start, size)` entry of `self._pci_cacheable`.

  Both ends of each window are inclusive: `start <= addr <= start + size`.
  """
  for start, size in self._pci_cacheable:
    if start <= addr <= start + size: return True
  return False
|
||||
|
||||
@@ -265,10 +253,9 @@ class CustomASM24Controller:
|
||||
result = b''
|
||||
for off in range(0, length, 0xFF):
|
||||
chunk = min(0xFF, length - off)
|
||||
buf = (ctypes.c_ubyte * chunk)()
|
||||
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, buf, chunk, 1000)
|
||||
ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, self._f0_out_buf, chunk, 1000)
|
||||
assert ret == chunk, f"read(0x{base_addr + off:04X}, {chunk}) failed: {ret}"
|
||||
result += bytes(buf[:ret])
|
||||
result += bytes(self._f0_out_buf[:ret])
|
||||
return result[:length]
|
||||
|
||||
def write(self, base_addr:int, data:bytes, **kwargs):
|
||||
@@ -288,6 +275,12 @@ class CustomASM24Controller:
|
||||
assert ret >= 0, f"F2 setup failed: {ret}"
|
||||
self.usb._bulk_out(0x02, buf_padded)
|
||||
|
||||
def scsi_read_arm(self, size:int):
  """Issue the 0xF2 vendor control transfer (request type 0x40, no data stage) that arms a read of `size` bytes.

  wValue carries `ceildiv(size, 512)` masked to 15 bits with bit 15 set; wIndex carries `ceildiv(size, 0x4000)`
  masked to 8 bits and placed in its high byte.

  Raises AssertionError when libusb returns a negative code.
  """
  wvalue = (ceildiv(size, 512) & 0x7FFF) | 0x8000
  windex = (ceildiv(size, 0x4000) & 0xFF) << 8
  ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF2, wvalue, windex, None, 0, 1000)
  assert ret >= 0, f"F2 read arm failed: {ret}"
|
||||
|
||||
def scsi_read(self, size:int) -> memoryview:
  """Bulk-read from endpoint 0x81 and return the first `size` bytes.

  The request length is `size` rounded up to a multiple of 512, with a 10000 timeout passed to `_bulk_in`.
  """
  padded = round_up(size, 512)
  return self.usb._bulk_in(0x81, padded, timeout=10000)[:size]
|
||||
|
||||
class ASM24Controller:
|
||||
def __init__(self, usb:USB3|None=None):
|
||||
@@ -429,6 +422,7 @@ class USBMMIOInterface(MMIOInterface):
|
||||
def _acc(self, off, sz, data=None):
|
||||
if data is None: # read op
|
||||
if not self.pcimem:
|
||||
if self.addr == 0xf000 and hasattr(self.usb, 'scsi_read'): return self.usb.scsi_read(sz)
|
||||
return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz)
|
||||
|
||||
# Fast path: streaming PCIe read if controller supports it
|
||||
|
||||
Reference in New Issue
Block a user