Refactor LoadOps (#910)

* test

* work

* upd test

* loadops

* cleanups

* real ones

* remove LazyNumpyArray

* fix assign test

* remove range

* np.require

* llama uses arange kernels

* no caching consts

* fix enet

* torch load support

* tests cleanup

* fix shufflenet

* fix image

* fix torch_load test
This commit is contained in:
George Hotz
2023-06-03 09:40:43 -07:00
committed by GitHub
parent d58586bb17
commit 791530045d
20 changed files with 254 additions and 117 deletions

View File

@@ -43,7 +43,7 @@ class CLBuffer(RawBufferCopyInOut):
super().__init__(size, dtype, buf)
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
- cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, x, is_blocking=False)
+ cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
def _copyout(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)