remove cpu and torch backends (#3399)

* remove cpu and torch backends

* don't copy to cpu

* use clang instead of cpu

* multitensor gathers on the first device

* clang is cpu + use default

* fixup

* bugfix
George Hotz
2024-02-15 16:55:39 +01:00
committed by GitHub
parent 75f7e21a80
commit b1c0d8c99d
12 changed files with 23 additions and 151 deletions
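
With the interpreted CPU (numpy) and TORCH backends removed, CPU execution goes through the compiled CLANG backend instead. A minimal sketch of the user-visible behavior after this change, assuming a tinygrad that exports Device from the top-level package and a machine where CLANG ends up as the default device (both assumptions, not part of this diff):

from tinygrad import Tensor, Device

print(Device.DEFAULT)  # e.g. "CLANG" on a box with no GPU backend (assumption)
t = (Tensor([1.0, 2.0, 3.0]) * 2).realize()  # compiled and run by the C backend, no numpy interpreter
print(t.numpy())  # [2. 4. 6.]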

tinygrad/runtime/ops_cpu.py (deleted)

@@ -1,48 +0,0 @@
import numpy as np
from typing import Callable, Dict, Tuple
from tinygrad.helpers import flat_mv
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator

def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)

def einsum_mulacc(einsum, get_strides, expand):
  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
  def get_input_axes(t, sum_axes): return tuple(i for i,stride in enumerate(get_strides(t)) if stride != 0 or i in sum_axes)
  def get_sliced_input(t, axes): return t[tuple(slice(None) if i in axes else 0 for i in range(len(get_strides(t))))]
  def mulacc(a, b, out_shape):
    sum_axes = tuple(i for i,s in enumerate(out_shape) if s == 1)
    a_axes, b_axes = get_input_axes(a, sum_axes), get_input_axes(b, sum_axes)
    a_input, b_input = get_sliced_input(a, a_axes), get_sliced_input(b, b_axes)
    out_axes = [i for i in range(len(out_shape)) if (i in a_axes or i in b_axes) and i not in sum_axes]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out_axes)}", a_input, b_input)
    return expand(ret.reshape(tuple(1 if (i not in a_axes and i not in b_axes) else s for i,s in enumerate(out_shape))), out_shape)
  return mulacc

def as_strided(x, arg):
  shape, stride, offset = arg
  return np.ndarray(shape, x.dtype, buffer=np.require(x, requirements='C'), offset=offset*x.dtype.itemsize,
                    strides=tuple(y*x.dtype.itemsize for y in stride))

numpy_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
  UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt, UnaryOps.NEG: np.negative,
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
  BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.max(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
  TernaryOps.WHERE: np.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: np.broadcast_to, MovementOps.PAD: np.pad
}

class NumpyAllocator(Allocator):
  def _alloc(self, size:int): return np.empty(size, dtype=np.uint8)
  def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
  def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
  def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)

class CPUDevice(Interpreted):
  def __init__(self, device:str): super().__init__(device, NumpyAllocator(), numpy_fxn_for_op)
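
For reference, a standalone sketch of what the deleted einsum_mulacc did for TernaryOps.MULACC: a matmul reaches the interpreter as two zero-stride broadcast inputs plus a SUM over the shared axis, and the helper recovers the contraction for np.einsum. The shapes and names below are illustrative, not taken from the diff:

import numpy as np

M, K, N = 2, 3, 4
a = np.broadcast_to(np.random.rand(M, K, 1), (M, K, N))  # stride 0 on axis 2
b = np.broadcast_to(np.random.rand(1, K, N), (M, K, N))  # stride 0 on axis 0
out_shape = (M, 1, N)  # MULACC output: SUM over axis 1, kept as size 1

# reduce_axis picks the axes where in/out shapes differ: here (1,)
assert tuple(i for i, (x, y) in enumerate(zip(a.shape, out_shape)) if x != y) == (1,)

# einsum_mulacc slices away the stride-0 axes (a[:, :, 0], b[0, :, :]) and
# contracts the summed axis: einscripts((0, 1)) = "ab", einscripts((1, 2)) = "bc"
ret = np.einsum("ab, bc -> ac", a[:, :, 0], b[0, :, :]).reshape(out_shape)
np.testing.assert_allclose(ret, (a * b).sum(axis=1, keepdims=True))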

tinygrad/runtime/ops_torch.py (deleted)

@@ -1,45 +0,0 @@
import torch
from typing import Dict, Callable
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, flatten
from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.bool: dtypes.bool,
            torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
            torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32, torch.float64: dtypes.float64}
inverse_type_map = {v: k for k,v in type_map.items()}
# TODO: should unsupported types fail instead of implicit conversion?
inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})

def as_strided(x, arg):
  shape, stride, offset = arg
  x = x.contiguous()
  offset += x.storage_offset()  # NOTE: contiguous can still have a storage_offset, so we adjust for it
  if any(i < 0 for i in stride):
    return torch.as_strided(x, shape, tuple(abs(i) for i in stride),
                            offset + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(shape, stride))).flip([i for i,a in enumerate(stride) if a < 0])
  return torch.as_strided(x, shape, stride, offset)

torch_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
  UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt, UnaryOps.NEG: torch.neg,
  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
  BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype),
  BinaryOps.XOR: torch.bitwise_xor, BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.amax(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
  TernaryOps.WHERE: torch.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
  MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, flatten(padding[::-1])),
}

class TorchAllocator(Allocator):
  def _alloc(self, size:int): return torch.empty([size], device=device, dtype=torch.uint8)
  def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
  def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())

class TorchDevice(Interpreted):
  def __init__(self, device:str): super().__init__(device, TorchAllocator(), torch_fxn_for_op)
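
The subtle part of the deleted torch backend is as_strided with negative strides, which torch.as_strided does not accept: the code views the buffer with the absolute strides from a shifted offset, then flips the affected axes to restore the original element order. A small standalone sketch of the same trick, with values chosen purely for illustration:

import torch

x = torch.arange(6.)
shape, stride, offset = (3,), (-2,), 4  # wants the elements at indices 4, 2, 0

# shift the offset to the last element the negative stride would reach...
adj = offset + sum((s - 1) * st for s, st in zip(shape, stride) if st < 0)  # 4 + 2*(-2) = 0
# ...then view with the positive stride and flip back into the requested order
y = torch.as_strided(x, shape, tuple(abs(st) for st in stride), adj).flip([0])
assert torch.equal(y, torch.tensor([4., 2., 0.]))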