diff --git a/test/mockgpu/usb.py b/test/mockgpu/usb.py index a7fba8790c..460b46100e 100644 --- a/test/mockgpu/usb.py +++ b/test/mockgpu/usb.py @@ -5,7 +5,7 @@ class MockUSB: def read(self, address, size): return bytes(self.mem[address:address+size]) - def write(self, address, data): + def write(self, address, data, ignore_cache=False): self.mem[address:address+len(data)] = data def pcie_mem_req(self, address, value=None, size=1): diff --git a/test/test_hcq_iface.py b/test/test_hcq_iface.py index 2a3e0974cb..bf32bf637c 100644 --- a/test/test_hcq_iface.py +++ b/test/test_hcq_iface.py @@ -62,7 +62,7 @@ class TestUSBMMIOInterface(unittest.TestCase): def test_getitem_setitem_byte(self): self.mmio[1] = 0xAB - self.assertEqual(self.mmio[1], bytes([0xAB])) + self.assertEqual(self.mmio[1], 0xAB) self.assertEqual(self.usb.mem[1], 0xAB) def test_slice_getitem_setitem(self): @@ -76,13 +76,13 @@ class TestUSBMMIOInterface(unittest.TestCase): def test_view(self): self.mmio[0] = 5 view = self.mmio.view(offset=1, size=3) - self.assertEqual(view[0], bytes([self.usb.mem[1]])) + self.assertEqual(view[0], self.usb.mem[1]) view[:] = [7, 8, 9] self.assertEqual(list(self.usb.mem[1:4]), [7, 8, 9]) full_view = self.mmio.view() self.assertEqual(len(full_view), len(self.mmio)) self.mmio[2] = 0xFE - self.assertEqual(full_view[2], bytes([0xFE])) + self.assertEqual(full_view[2], 0xFE) def test_pcimem_byte(self): usb2 = MockUSB(bytearray(self.size)) diff --git a/tinygrad/runtime/support/usb.py b/tinygrad/runtime/support/usb.py index 83c6f6b9de..0df50515e3 100644 --- a/tinygrad/runtime/support/usb.py +++ b/tinygrad/runtime/support/usb.py @@ -231,7 +231,8 @@ class USBMMIOInterface(MMIOInterface): def _acc(self, off, sz, data=None): if data is None: # read op - if not self.pcimem: return self.usb.read(self.addr + off, sz) + if not self.pcimem: + return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz) acc, acc_size = self._acc_size(sz) return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)])) @@ -240,7 +241,8 @@ class USBMMIOInterface(MMIOInterface): if not self.pcimem: # Fast path for writing into buffer 0xf000 - return self.usb.scsi_write(bytes(data)) if self.addr == 0xf000 else self.usb.write(self.addr + off, bytes(data)) + use_cache = 0xa000 <= self.addr <= 0xb200 + return self.usb.scsi_write(bytes(data)) if self.addr == 0xf000 else self.usb.write(self.addr + off, bytes(data), ignore_cache=not use_cache) _, acc_sz = self._acc_size(len(data) * struct.calcsize(self.fmt)) self.usb.pcie_mem_write(self.addr+off, [int.from_bytes(data[i:i+acc_sz], "little") for i in range(0, len(data), acc_sz)], acc_sz)