mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fast path for copy (#2548)
* fast copy * ruff first * flat_mv on malloc * order + webgpu test
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from typing import Callable, Dict, Tuple
|
||||
from tinygrad.helpers import dtypes, DType
|
||||
from tinygrad.helpers import dtypes, DType, flat_mv
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
|
||||
@@ -41,7 +41,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
|
||||
|
||||
class NumpyAllocator(Allocator):
|
||||
def _alloc(self, size:int, dtype:DType): return np.empty(size, dtype.np)
|
||||
def as_buffer(self, src:np.ndarray) -> memoryview: return np.require(src, requirements='C').data
|
||||
def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
|
||||
def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
|
||||
def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user