fast path for copy (#2548)
* fast copy
* ruff first
* flat_mv on malloc
* order + webgpu test
@@ -7,18 +7,18 @@ repos:
       language: system
       always_run: true
       pass_filenames: false
-    - id: mypy
-      name: mypy
-      entry: mypy tinygrad/ extra/helpers.py
-      language: system
-      always_run: true
-      pass_filenames: false
     - id: ruff
      name: ruff
      entry: ruff . --preview
      language: system
      always_run: true
      pass_filenames: false
+    - id: mypy
+      name: mypy
+      entry: mypy tinygrad/ extra/helpers.py
+      language: system
+      always_run: true
+      pass_filenames: false
    - id: docs
      name: docs
      entry: python3 docs/abstractions.py
@@ -19,14 +19,14 @@ function cleanup(err) {
 async function waitForText(selector, text) {
   let n = 0;
   let ready = false;
-  while (n < 10) {
+  while (n < 30) {
     const res = await (await selector.getProperty("textContent")).jsonValue();
     console.log(`waiting for text ${text} got ${res}`);
     if(res == text) {
       ready = true;
       break
     }
-    await timeout(2000);
+    await timeout(1000);
     n += 1
   }
   return ready;
@@ -3,7 +3,7 @@ import numpy as np
 from collections import defaultdict
 from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable, Tuple
 import importlib, inspect, functools, pathlib, time, re
-from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes
+from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, Op
 
@@ -45,8 +45,28 @@ class Buffer:
     if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     self.allocator.free(self._buf, self.size, self.dtype)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size}>"
+  def copy_(self, src:Buffer):
+    assert self.size == src.size and self.dtype == src.dtype, "buffer copy size/dtype mismatch"
+    if hasattr(self.allocator, 'transfer') and type(self.allocator) is type(src.allocator):
+      # fast path, used on HIP between GPUs
+      self.allocator.transfer(self._buf, src._buf, self.size*self.dtype.itemsize)
+      return
+    if hasattr(self.allocator, 'from_buffer') and hasattr(self.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
+      # fast path, used on Metal in OS X Sonoma
+      fb = self.allocator.from_buffer(src.allocator.as_buffer(src._buf))
+      if fb:
+        self.allocator.transfer(self._buf, fb, self.size*self.dtype.itemsize)
+        return
+    if hasattr(self.allocator, 'as_buffer'):
+      # fast(ish) path, uses readinto in diskbuffers
+      src.allocator.copyout(self.allocator.as_buffer(self._buf), src._buf)
+    elif hasattr(src.allocator, 'as_buffer'):
+      self.allocator.copyin(self._buf, src.allocator.as_buffer(src._buf))
+    else:
+      # slow path, allocates a CPU buffer
+      self.copyin(src.toCPU().data)
   def copyin(self, mv:memoryview):
-    mv = mv.cast("B", shape=[self.size*self.dtype.itemsize])
+    mv = flat_mv(mv)
     assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
     self.allocator.copyin(self._buf, mv)
    return self
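For intuition, here is a minimal, self-contained sketch (not tinygrad code; all names are illustrative) of the fallback order `copy_` walks through: a same-allocator `transfer` first, then zero-copy `as_buffer` views on either side, and only then a CPU round trip.

```python
# Minimal sketch of the copy_ dispatch order, using toy host-memory allocators.
# These classes and names are illustrative, not part of tinygrad.
class HostAllocator:
  def alloc(self, nbytes:int) -> bytearray: return bytearray(nbytes)
  def as_buffer(self, buf:bytearray) -> memoryview: return memoryview(buf)   # zero-copy view
  def copyin(self, dest:bytearray, src:memoryview): dest[:] = src
  def copyout(self, dest:memoryview, src:bytearray): dest[:] = src

def copy_between(dst_alloc, dst, src_alloc, src, nbytes:int):
  # mirrors the hasattr-based fallbacks in Buffer.copy_, simplified
  if hasattr(dst_alloc, 'transfer') and type(dst_alloc) is type(src_alloc):
    dst_alloc.transfer(dst, src, nbytes); return          # fast path (e.g. GPU <-> GPU)
  if hasattr(dst_alloc, 'as_buffer'):
    src_alloc.copyout(dst_alloc.as_buffer(dst), src)      # write straight into dst's view
  elif hasattr(src_alloc, 'as_buffer'):
    dst_alloc.copyin(dst, src_alloc.as_buffer(src))       # read straight out of src's view
  else:
    raise NotImplementedError("slow path: bounce through a CPU buffer")

a, b = HostAllocator(), HostAllocator()
src, dst = a.alloc(4), b.alloc(4)
src[:] = b"abcd"
copy_between(b, dst, a, src, 4)
assert bytes(dst) == b"abcd"
```

The real `copy_` above additionally tries `from_buffer`, which lets the Metal backend wrap a host view in a device buffer and reuse the same `transfer` path.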
@@ -56,7 +76,7 @@ class Buffer:
     # zero copy with as_buffer
     if hasattr(self.allocator, 'as_buffer'): return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf}))
     ret = np.empty(self.size, self.dtype.np)
-    if self.size > 0: self.allocator.copyout(ret.data.cast("B", shape=[self.size*self.dtype.itemsize]), self._buf)
+    if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
     return ret
 
 # TODO: size, dest, src are the same type. can we enforce this?
@@ -255,7 +275,7 @@ def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) ->
 import ctypes
 class _MallocAllocator(LRUAllocator):
   def _alloc(self, size:int, dtype:DType): return (ctypes.c_uint8 * (size*dtype.itemsize))()
-  def as_buffer(self, src) -> memoryview: return memoryview(src)
+  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
   def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
   def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
 MallocAllocator = _MallocAllocator()
@@ -65,7 +65,9 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
 def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1]
+def flat_mv(mv:memoryview):
+  if len(mv) == 0: return mv
+  return mv.cast("B", shape=(mv.nbytes,))
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
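As a quick illustration of what the new helper does (a standalone example, not from the repo's tests): `flat_mv` turns any C-contiguous memoryview, whatever its shape and format, into a flat unsigned-byte view, which is the form the allocator `copyin`/`copyout` paths assert on.

```python
import ctypes
import numpy as np

# local copy of the helper added above, so the demo is self-contained
def flat_mv(mv:memoryview):
  if len(mv) == 0: return mv
  return mv.cast("B", shape=(mv.nbytes,))

# a malloc'd ctypes array and a 2-D numpy array both flatten to plain bytes
cbuf = (ctypes.c_uint8 * 12)()
arr  = np.arange(6, dtype=np.float32).reshape(2, 3)

for mv in (memoryview(cbuf), arr.data):
  flat = flat_mv(mv)
  assert flat.format == "B" and flat.shape == (flat.nbytes,)
```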
@@ -30,6 +30,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
     # TODO: size 0 should be removed from the schedule
     if si.out.realized.size != 0:
       if si.ast.op in LoadOps:
+        if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
         # confirm the LoadOps are contiguous and in order
         for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
         kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
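The new debug line nests one f-string inside another: the inner conditional builds the `dst <- src` text only for `LoadOps.FROM`, and the outer `:25s` format spec pads whichever string results. A standalone illustration of that pattern (values invented for the example):

```python
# Nested f-string with a format spec applied to a conditional expression.
op, dst, src = "FROM", "METAL", "DISK"
line = f"*** {op:>6s} {f'{dst} <- {src}' if op == 'FROM' else dst:16s} sz {128:5d}"
print(line)  # -> '***   FROM METAL <- DISK    sz   128'
```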
@@ -41,34 +42,19 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
     #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
     assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
 
-# *** zero op LoadOps ***
-
-def _realize_empty(buffer: Buffer) -> None:
-  if DEBUG >= 2: print(f"*** empty {buffer.device} shape {buffer.size:5d} dtype {buffer.dtype}")
+# *** LoadOps implementation ***
 
 # TODO: remove this and write the RNG in tinygrad
 def _realize_rand(buffer: Buffer, arg) -> None:
-  if DEBUG >= 2: print(f"*** rand {buffer.device} seed {arg:<10d} shape {buffer.size:5d} dtype {buffer.dtype}")
   rng = np.random.default_rng(arg)
   rng_np_buffer = rng.random(size=buffer.size, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False)
   buffer.copyin(rng_np_buffer.data)
 
-# *** one op LoadOps ***
-
-def _realize_from(buffer: Buffer, src: Buffer) -> None:
-  assert src.size == buffer.size, f"size mismatch on FROM {src.size=} != {buffer.size=}"
-  if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.size:<16d} shape {buffer.size:5d} dtype {src.dtype}")
-  buffer.copyin(src.toCPU().data)
-
-# *** n op LoadOps ***
-
-def _realize_custom(buffer: Buffer, *inputs: Buffer, arg) -> None:
-  if DEBUG >= 2: print(f"*** custom {buffer.device} shape {buffer.size:5d} dtype {buffer.dtype}")
-  arg(buffer, *inputs)
+def _realize_custom(*buffers: Buffer, arg) -> None: arg(*buffers)
 
 LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
-  LoadOps.EMPTY: _realize_empty,
+  LoadOps.EMPTY: lambda x: None,
   LoadOps.RAND: _realize_rand,
-  LoadOps.FROM: _realize_from,
-  LoadOps.CUSTOM: _realize_custom,
+  LoadOps.FROM: Buffer.copy_,
+  LoadOps.CUSTOM: _realize_custom
 }
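Worth noting in the dispatcher: `LoadOps.FROM` now maps to the unbound `Buffer.copy_`, so the schedule runner can call it like any other entry, presumably with the realized output buffer first and the realized source second. A tiny standalone sketch of that pattern (illustrative classes, not tinygrad's):

```python
from enum import Enum, auto
from typing import Callable, Dict

class LoadOps(Enum): EMPTY = auto(); FROM = auto()         # illustrative subset

class Buffer:
  def __init__(self, data=None): self.data = data
  def copy_(self, src:"Buffer"): self.data = src.data      # stand-in for the real device copy

# an unbound method stored in a dict is just a plain function taking `self` first
DISPATCHER: Dict[LoadOps, Callable] = {
  LoadOps.EMPTY: lambda out: None,
  LoadOps.FROM: Buffer.copy_,
}

out, src = Buffer(), Buffer("payload")
DISPATCHER[LoadOps.FROM](out, src)   # same call shape a schedule runner would use
assert out.data == "payload"
```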
@@ -1,6 +1,6 @@
 import numpy as np
 from typing import Callable, Dict, Tuple
-from tinygrad.helpers import dtypes, DType
+from tinygrad.helpers import dtypes, DType, flat_mv
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op
 from tinygrad.device import Interpreted, Allocator
 
@@ -41,7 +41,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
 
 class NumpyAllocator(Allocator):
   def _alloc(self, size:int, dtype:DType): return np.empty(size, dtype.np)
-  def as_buffer(self, src:np.ndarray) -> memoryview: return np.require(src, requirements='C').data
+  def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
   def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
   def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
 
@@ -19,7 +19,7 @@ class DiskBuffer:
   def as_strided(self, arg):
     assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
     return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize)
-  def _buf(self) -> memoryview: return memoryview(self.ud.mem).cast("B")[self.offset:self.offset+self.size*self.dtype.itemsize]
+  def _buf(self) -> memoryview: return memoryview(self.ud.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
 
 disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided }
 
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import os, subprocess, pathlib, ctypes, tempfile, functools
 import Metal, libdispatch
-from typing import List, Any, Tuple
+from typing import List, Any, Tuple, Optional
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, diskcache, unwrap2
 from tinygrad.device import Compiled, LRUAllocator
@@ -51,32 +51,25 @@ class MetalAllocator(LRUAllocator):
   def __init__(self, device:MetalDevice):
     self.device:MetalDevice = device
     super().__init__()
-  def _alloc(self, size:int, dtype:DType):
+  def _alloc(self, size:int, dtype:DType) -> Any:
     if size == 0: return None
     ret = self.device.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
     if ret is None: raise MemoryError(f"Metal OOM while allocating {size=} {dtype=}")
     return ret
-  def _async_copy(self, dest, src):
-    assert src.length() == dest.length(), f"length mismatch {src.length()=} {dest.length()=}"
+  def transfer(self, dest:Any, src:Any, sz:int):
     command_buffer = self.device.mtl_queue.commandBuffer()
     encoder = command_buffer.blitCommandEncoder()
-    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, src.length())
+    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
     encoder.endEncoding()
     command_buffer.commit()
     self.device.mtl_buffers_in_flight.append(command_buffer)
-  def _from_buffer(self, src:memoryview): return self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, len(src), Metal.MTLResourceStorageModeShared, None)
-  def _free(self, opaque): opaque.release()
-  def as_buffer(self, src) -> memoryview:
+  def from_buffer(self, src:memoryview) -> Optional[Any]: return self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, len(src), Metal.MTLResourceStorageModeShared, None)
+  def _free(self, opaque:Any): opaque.release()
+  def as_buffer(self, src:Any) -> memoryview:
     self.device.synchronize()
     return src.contents().as_buffer(src.length())
-  def copyin(self, dest, src:memoryview):
-    src_from_buffer = None if getenv("SLOW_METAL_COPY") else self._from_buffer(src)
-    if src_from_buffer is None:
-      self.as_buffer(dest)[:] = src
-    else:
-      self.device.copies_in_flight.append(src)
-      self._async_copy(dest, src_from_buffer)
-  def copyout(self, dest:memoryview, src): dest[:] = self.as_buffer(src)
+  def copyin(self, dest:Any, src:memoryview): self.as_buffer(dest)[:] = src
+  def copyout(self, dest:memoryview, src:Any): dest[:] = self.as_buffer(src)
 
 class MetalDevice(Compiled):
   compiler_device = None
@@ -3,7 +3,7 @@ import numpy as np
 from typing import Dict, Callable
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op
 from tinygrad.device import Interpreted, Allocator
-from tinygrad.helpers import getenv, dtypes, DType
+from tinygrad.helpers import getenv, dtypes, DType, flat_mv
 from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
@@ -43,7 +43,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
 
 class TorchAllocator(Allocator):
   def _alloc(self, size:int, dtype:DType): return torch.empty([size], device=device, dtype=inverse_type_map[dtype])
-  def as_buffer(self, src:torch.Tensor) -> memoryview: return np.require(src.numpy(), requirements='C').data
+  def as_buffer(self, src:torch.Tensor) -> memoryview: return flat_mv(np.require(src.numpy(), requirements='C').data)
   def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
   def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
 