Batch PCIe memory writes issued over the USB bridge.

Summary of this patch:
  * ASM24Controller.pcie_request is split in two: pcie_prep_request only
    builds the list of WriteOp register writes for one request, and
    pcie_request executes that list and then continues as before.
  * ASM24Controller.pcie_mem_write is added: it prepares one 0x40 request per
    value and sends the prepared requests through exec_ops four at a time.
  * MockUSB gains a matching pcie_mem_write that loops over pcie_mem_req.
  * USBMMIOInterface's PCIe write path calls usb.pcie_mem_write once instead
    of calling _acc_one per element.
  * USB3's default max_streams goes from 16 to 31.

diff --git a/test/mockgpu/usb.py b/test/mockgpu/usb.py
index 24c7691f0b..a7fba8790c 100644
--- a/test/mockgpu/usb.py
+++ b/test/mockgpu/usb.py
@@ -10,4 +10,7 @@ class MockUSB:
 
   def pcie_mem_req(self, address, value=None, size=1):
     if value is None: return int.from_bytes(self.mem[address:address+size], "little")
-    else: self.mem[address:address+size] = value.to_bytes(size, "little")
\ No newline at end of file
+    else: self.mem[address:address+size] = value.to_bytes(size, "little")
+
+  def pcie_mem_write(self, address, values, size):
+    for i, value in enumerate(values): self.pcie_mem_req(address + i * size, value, size)
diff --git a/tinygrad/runtime/support/usb.py b/tinygrad/runtime/support/usb.py
index 00e3347884..83c6f6b9de 100644
--- a/tinygrad/runtime/support/usb.py
+++ b/tinygrad/runtime/support/usb.py
@@ -1,11 +1,11 @@
-import ctypes, struct, dataclasses, array
+import ctypes, struct, dataclasses, array, itertools
 from typing import Sequence
 from tinygrad.runtime.autogen import libusb
 from tinygrad.helpers import DEBUG
 from tinygrad.runtime.support.hcq import MMIOInterface
 
 class USB3:
-  def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=16, max_read_len:int=4096):
+  def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, max_read_len:int=4096):
     self.vendor, self.dev = vendor, dev
     self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
     self.max_streams, self.max_read_len = max_streams, max_read_len
@@ -150,23 +150,20 @@ class ASM24Controller:
     parts = self.exec_ops([ReadOp(base_addr + off, min(stride, length - off)) for off in range(0, length, stride)])
     return b''.join(p or b'' for p in parts)[:length]
 
-  def pcie_request(self, fmt_type, address, value=None, size=4, cnt=10):
+  def pcie_prep_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4) -> list[WriteOp]:
     assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
-    if DEBUG >= 3: print("pcie_request", hex(fmt_type), hex(address), value, size, cnt)
+    if DEBUG >= 3: print("pcie_prep_req", hex(fmt_type), hex(address), value, size)
 
     masked_address, offset = address & 0xFFFFFFFC, address & 0x3
-    assert size + offset <= 4
+    assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)
 
-    ops = []
-    if value is not None:
-      assert value >> (8 * size) == 0
-      ops.append(WriteOp(0xB220, struct.pack('>I', value << (8 * offset)), ignore_cache=False))
+    return ([WriteOp(0xB220, struct.pack('>I', value << (8 * offset)), ignore_cache=False)] if value is not None else []) + \
+      [WriteOp(0xB218, struct.pack('>I', masked_address), ignore_cache=False),
+       WriteOp(0xB217, bytes([((1 << size) - 1) << offset]), ignore_cache=False), WriteOp(0xB210, bytes([fmt_type]), ignore_cache=False),
+       WriteOp(0xB254, b"\x0f", ignore_cache=True), WriteOp(0xB296, b"\x04", ignore_cache=True)]
 
-    ops += [WriteOp(0xB218, struct.pack('>I', masked_address), ignore_cache=False),
-            WriteOp(0xB217, bytes([((1 << size) - 1) << offset]), ignore_cache=False),
-            WriteOp(0xB210, bytes([fmt_type]), ignore_cache=False),
-            WriteOp(0xB254, b"\x0f", ignore_cache=True), WriteOp(0xB296, b"\x04", ignore_cache=True)]
-    self.exec_ops(ops)
+  def pcie_request(self, fmt_type, address, value=None, size=4, cnt=10):
+    self.exec_ops(self.pcie_prep_request(fmt_type, address, value, size))
 
     # Fast path for write requests
     if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000): return
@@ -194,7 +191,7 @@ class ASM24Controller:
         0b100: "Completer Abort: abort due to internal error", 0b010: "Configuration Request Retry Status: configuration space busy"}
       raise RuntimeError(f"TLP status: {status_map.get(status, 'Reserved (0b{:03b})'.format(status))}")
 
-    if value is None: return (struct.unpack('>I', self.read(0xB220, 4))[0] >> (8 * offset)) & ((1 << (8 * size)) - 1)
+    if value is None: return (struct.unpack('>I', self.read(0xB220, 4))[0] >> (8 * (address & 0x3))) & ((1 << (8 * size)) - 1)
 
   def pcie_cfg_req(self, byte_addr, bus=1, dev=0, fn=0, value=None, size=4):
     assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0, f"Invalid byte_addr {byte_addr}, bus {bus}, dev {dev}, fn {fn}"
@@ -205,6 +202,12 @@ class ASM24Controller:
 
   def pcie_mem_req(self, address, value=None, size=4): return self.pcie_request(0x40 if value is not None else 0x0, address, value, size)
 
+  def pcie_mem_write(self, address, values, size):
+    ops = [self.pcie_prep_request(0x40, address + i * size, value, size) for i, value in enumerate(values)]
+
+    # Send in batches of 4
+    for i in range(0, len(ops), 4): self.exec_ops(list(itertools.chain.from_iterable(ops[i:i+4])))
+
 class USBMMIOInterface(MMIOInterface):
   def __init__(self, usb, addr, size, fmt, pcimem=True):
     self.usb, self.addr, self.nbytes, self.fmt, self.pcimem, self.el_sz = usb, addr, size, fmt, pcimem, struct.calcsize(fmt)
@@ -240,4 +243,4 @@ class USBMMIOInterface(MMIOInterface):
       return self.usb.scsi_write(bytes(data)) if self.addr == 0xf000 else self.usb.write(self.addr + off, bytes(data))
 
     _, acc_sz = self._acc_size(len(data) * struct.calcsize(self.fmt))
-    for i in range(0, len(data), acc_sz): self._acc_one(off + i, acc_sz, int.from_bytes(data[i:i+acc_sz], "little"))
+    self.usb.pcie_mem_write(self.addr+off, [int.from_bytes(data[i:i+acc_sz], "little") for i in range(0, len(data), acc_sz)], acc_sz)