use host ptr for speed on copyouts (#1393)

* feat: use mapped buffer for speed

* fix: whoops don't need that

* feat: don't need explicit call to memoryview
This commit is contained in:
wozeparrot
2023-08-01 12:34:12 -04:00
committed by GitHub
parent ba5e3818a0
commit 7c7cf16ef2

View File

@@ -43,13 +43,15 @@ class CLBuffer(RawBufferCopyInOut):
     setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
     super().__init__(size, dtype, buf)

-  def _copyin(self, x: np.ndarray):
+  def _copyin(self, x:np.ndarray):
     assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
     CL.events_in_flight.append(cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False))
   def _copyout(self, x:np.ndarray):
-    CL.synchronize()
     assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
-    cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)
+    buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
+    mapped = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np)
+    CL.synchronize()
+    with mapped[0].base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped[0], self._buf, is_blocking=True, wait_for=[mapped[1]])

 class CLProgram:
   def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):