mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-15 09:05:40 -05:00
remove cpu and torch backends (#3399)
* remove cpu and torch backends
* don't copy to cpu
* use clang instead of cpu
* multitensor gathers on the first device
* clang is cpu + use default
* fixup
* bugfix
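As a hedged, illustrative sketch of what the change means for user code (the device names and the default shown are assumptions about the machine, not part of this commit): the numpy-interpreted CPU backend and the torch backend deleted below go away, and plain-CPU execution is served by the CLANG backend instead.

from tinygrad import Tensor, Device
print(Device.DEFAULT)                        # e.g. "CLANG" on a machine with no GPU (assumption)
t = Tensor([1.0, 2.0, 3.0], device="CLANG")  # previously this could have been device="CPU" or device="TORCH"
print((t + t).numpy())                       # [2. 4. 6.]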
tinygrad/runtime/ops_cpu.py
@@ -1,48 +0,0 @@
import numpy as np
from typing import Callable, Dict, Tuple
from tinygrad.helpers import flat_mv
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator

def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)

def einsum_mulacc(einsum, get_strides, expand):
  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
  def get_input_axes(t, sum_axes): return tuple(i for i,stride in enumerate(get_strides(t)) if stride != 0 or i in sum_axes)
  def get_sliced_input(t, axes): return t[tuple(slice(None) if i in axes else 0 for i in range(len(get_strides(t))))]
  def mulacc(a, b, out_shape):
    sum_axes = tuple(i for i,s in enumerate(out_shape) if s == 1)
    a_axes, b_axes = get_input_axes(a, sum_axes), get_input_axes(b, sum_axes)
    a_input, b_input = get_sliced_input(a, a_axes), get_sliced_input(b, b_axes)
    out_axes = [i for i in range(len(out_shape)) if (i in a_axes or i in b_axes) and i not in sum_axes]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out_axes)}", a_input, b_input)
    return expand(ret.reshape(tuple(1 if (i not in a_axes and i not in b_axes) else s for i,s in enumerate(out_shape))), out_shape)
  return mulacc

def as_strided(x, arg):
  shape, stride, offset = arg
  return np.ndarray(shape, x.dtype, buffer=np.require(x, requirements='C'), offset=offset*x.dtype.itemsize,
                    strides=tuple(y*x.dtype.itemsize for y in stride))

numpy_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
  UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt, UnaryOps.NEG: np.negative,
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
  BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.max(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
  TernaryOps.WHERE: np.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: np.broadcast_to, MovementOps.PAD: np.pad
}

class NumpyAllocator(Allocator):
  def _alloc(self, size:int): return np.empty(size, dtype=np.uint8)
  def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
  def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
  def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)

class CPUDevice(Interpreted):
  def __init__(self, device:str): super().__init__(device, NumpyAllocator(), numpy_fxn_for_op)
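A minimal sketch (not part of the removed file) of how the interpreter above lowered TernaryOps.MULACC through einsum_mulacc: both operands arrive pre-broadcast to a common shape, the expanded axes carry stride 0, and the output shape marks reduced axes with 1. The shapes and values below are made up for illustration.

# assumes numpy and the einsum_mulacc helper from the deleted file above are in scope
import numpy as np
mulacc = einsum_mulacc(lambda s, a, b: np.einsum(s, a, b), lambda x: x.strides, np.broadcast_to)
A, B = np.arange(6.).reshape(2, 3), np.arange(12.).reshape(3, 4)
a = np.broadcast_to(A.reshape(2, 1, 3), (2, 4, 3))    # stride 0 on axis 1 (the expanded axis)
b = np.broadcast_to(B.T.reshape(1, 4, 3), (2, 4, 3))  # stride 0 on axis 0 (the expanded axis)
out = mulacc(a, b, (2, 4, 1))                         # reduce over the trailing axis
assert np.allclose(out[..., 0], A @ B)                # MULACC of the broadcast pair matches a matmul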
tinygrad/runtime/ops_torch.py
@@ -1,45 +0,0 @@
import torch
from typing import Dict, Callable
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, flatten
from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.bool: dtypes.bool,
            torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
            torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32, torch.float64: dtypes.float64}
inverse_type_map = {v: k for k,v in type_map.items()}
# TODO: should unsupported types fail instead of implicit conversion?
inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})

def as_strided(x, arg):
  shape, stride, offset = arg
  x = x.contiguous()
  offset += x.storage_offset() # NOTE: contiguous can still have a storage_offset, so we adjust for it
  if any(i < 0 for i in stride):
    return torch.as_strided(x, shape, tuple(abs(i) for i in stride),
                            offset + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(shape, stride))).flip([i for i,a in enumerate(stride) if a < 0])
  return torch.as_strided(x, shape, stride, offset)

torch_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
  UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt, UnaryOps.NEG: torch.neg,
  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
  BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype),
  BinaryOps.XOR: torch.bitwise_xor, BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.amax(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
  TernaryOps.WHERE: torch.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
  MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, flatten(padding[::-1])),
}

class TorchAllocator(Allocator):
  def _alloc(self, size:int): return torch.empty([size], device=device, dtype=torch.uint8)
  def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
  def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())

class TorchDevice(Interpreted):
  def __init__(self, device:str): super().__init__(device, TorchAllocator(), torch_fxn_for_op)
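Also illustrative (not part of the removed file): torch.as_strided does not accept negative strides, which is why the as_strided helper above rebases the offset and applies flip for that case. A tiny made-up check of that branch:

# assumes torch and the as_strided helper from the deleted file above are in scope
x = torch.arange(6)
y = as_strided(x, ((6,), (-1,), 5))        # a reversed view: shape (6,), stride (-1,), offset 5
assert torch.equal(y, torch.flip(x, [0]))  # the negative-stride branch reproduces a flip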