diff --git a/tinygrad/runtime/support/usb.py b/tinygrad/runtime/support/usb.py index 920d798aba..5e64b207ba 100644 --- a/tinygrad/runtime/support/usb.py +++ b/tinygrad/runtime/support/usb.py @@ -1,7 +1,7 @@ import ctypes, struct, dataclasses, array, itertools from typing import Sequence from tinygrad.runtime.autogen import libusb -from tinygrad.helpers import DEBUG +from tinygrad.helpers import DEBUG, to_mv from tinygrad.runtime.support.hcq import MMIOInterface class USB3: @@ -46,6 +46,7 @@ class USB3: self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)] self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)] self.buf_data_out = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)] + self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x1000) for i in range(self.max_streams)] def _prep_transfer(self, tr, ep, stream_id, buf, length): tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf @@ -86,8 +87,11 @@ class USB3: tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen)) if send_data is not None: - if len(send_data) > len(self.buf_data_out[slot]): self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))() - self.buf_data_out[slot][:len(send_data)] = list(send_data) + if len(send_data) > len(self.buf_data_out[slot]): + self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))() + self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data)) + + self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data) tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data))) op_window.append((idx, slot, rlen))