update cpu and torch to hold buffers (#542)

* update cpu and torch to hold buffers

* save lines, and probably faster
This commit is contained in:
George Hotz
2023-02-08 09:40:45 -06:00
committed by GitHub
parent ae4f0aeb5f
commit 996e0a10b7
2 changed files with 40 additions and 40 deletions

View File

@@ -1,40 +1,35 @@
from __future__ import annotations
import operator
import numpy as np
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, GenericExecAST
from tinygrad.helpers import shape_to_axis
class CPUBuffer(np.ndarray, GenericExecAST):
fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:].contiguous(), UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)]
}
base_fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
}
def relu(x): return np.maximum(x, 0)
def exp(x): return np.exp(x)
def log(x): return np.log(x)
def float(x): return x.astype(np.float32)
def flip(x, axis): return np.flip(x, axis)
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
def permute(x, order): return x.transpose(order)
def pad(x, padding): return np.pad(x, padding).view(CPUBuffer)
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
def strided(x, arg): return np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg]).view(CPUBuffer)
class CPUBuffer(GenericExecAST):
fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
})
def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
def fromCPU(x): return x.view(CPUBuffer)
def toCPU(x): return x
def fromCPU(x): return CPUBuffer(x)
def toCPU(x): return x.buf
def contiguous(x): return x.ravel().reshape(x.shape)
def unary_op(x, op): return CPUBuffer.fxn_for_op[op](x)
def binary_op(x, op, y): return CPUBuffer.fxn_for_op[op](x, y)
def reduce_op(x, op, new_shape): return CPUBuffer.fxn_for_op[op](x, new_shape)
def movement_op(x, op, arg=None): return CPUBuffer.fxn_for_op[op](x, arg) if op in CPUBuffer.fxn_for_op else getattr(x, op.name.lower())(arg)
def contiguous(x): return x.unary_op(UnaryOps.NOOP)
def unary_op(x, op): return type(x)(x.fxn_for_op[op](x.buf))
def binary_op(x, op, y): return type(x)(x.fxn_for_op[op](x.buf, y.buf))
def reduce_op(x, op, new_shape): return type(x)(x.fxn_for_op[op](x.buf, new_shape))
def movement_op(x, op, arg=None): return type(x)(x.fxn_for_op[op](x.buf, arg)) if op in x.fxn_for_op else type(x)(getattr(x.buf, op.name.lower())(arg))
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
@@ -42,6 +37,6 @@ class CPUBuffer(np.ndarray, GenericExecAST):
tx = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
(C.oy, C.sy*x.shape[3]), (C.ox, C.sx), (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
out = np.einsum("nGhwCHW, GkCHW -> nGkhw", tx.ravel().reshape(tx.shape), tw.ravel().reshape(tw.shape))
return out.reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
tw = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
out = np.einsum("nGhwCHW, GkCHW -> nGkhw", tx.buf.ravel().reshape(tx.shape), tw.buf.ravel().reshape(tw.shape))
return CPUBuffer(out.reshape(C.bs, C.groups*C.rcout, C.oy, C.ox))

View File

@@ -1,21 +1,26 @@
import torch
from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
from tinygrad.ops import ProcessingOps, GenericExecAST
from tinygrad.llops.ops_cpu import base_fxn_for_op, CPUBuffer # type: ignore
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST
from tinygrad.helpers import getenv
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(torch.Tensor, GenericExecAST):
def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
class TorchBuffer(GenericExecAST):
fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
})
def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
def toCPU(x): return x.cpu().numpy()
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
def toCPU(x): return x.buf.cpu().numpy()
unary_op, binary_op, reduce_op, movement_op = CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op
contiguous, unary_op, binary_op, reduce_op, movement_op = CPUBuffer.contiguous, CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op
SUPPORTS_SIMPLE_PADDING = True
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
assert C.px == C.px_ and C.py == C.py_, "asymmetric padding in conv is not supported"
return torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
return TorchBuffer(torch.conv2d(x.buf, w.buf, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)))