move base_fxn_for_op to ops_cpu

George Hotz
2023-02-08 18:23:48 -06:00
parent c642f5e72b
commit 16a7edc775
3 changed files with 22 additions and 19 deletions

tinygrad/llops/ops_cpu.py

@@ -1,16 +1,26 @@
 import numpy as np
-from typing import ClassVar
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
+import operator
+from typing import ClassVar, Callable, Dict
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, ProcessingOps, GenericExecAST, Op
+from tinygrad.helpers import shape_to_axis
-specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
+base_fxn_for_op : Dict[Op, Callable] = {
+  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
+  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
+  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
+  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
+  MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
+}
+numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
   MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
   MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
   MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
-})
+}}
 class CPUBuffer(GenericExecAST):
-  fxn_for_op : ClassVar = specialized_fxn_for_op
+  fxn_for_op : ClassVar = numpy_fxn_for_op
   @staticmethod
   def fromCPU(x): return CPUBuffer(x)
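
For context, fxn_for_op is the backend's dispatch table: GenericExecAST executes each Op by looking it up in this dict and calling the returned function on the underlying arrays. A minimal usage sketch of the merged numpy table, assuming tinygrad at this commit is importable:

import numpy as np
from tinygrad.ops import BinaryOps, ReduceOps
from tinygrad.llops.ops_cpu import numpy_fxn_for_op

# executing an op is just a dict lookup plus a call on numpy arrays
x = np.arange(6, dtype=np.float32).reshape(2, 3)
y = np.ones((2, 3), dtype=np.float32)
out = numpy_fxn_for_op[BinaryOps.ADD](x, y)         # operator.add, inherited from base_fxn_for_op
red = numpy_fxn_for_op[ReduceOps.SUM](out, (2, 1))  # reduces axis 1 with keepdims=True
print(out.shape, red.shape)                         # (2, 3) (2, 1)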

tinygrad/llops/ops_torch.py

@@ -1,18 +1,19 @@
 import torch
-from typing import ClassVar, Final
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
+from typing import ClassVar, Final, Dict, Callable
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, Op
 from tinygrad.helpers import getenv
+from tinygrad.llops.ops_cpu import base_fxn_for_op
-specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
+torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]),
   ProcessingOps.CONV: lambda x,w,C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
-})
+}}
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 class TorchBuffer(GenericExecAST):
-  fxn_for_op : ClassVar = specialized_fxn_for_op
+  fxn_for_op : ClassVar = torch_fxn_for_op
   SUPPORTS_SIMPLE_PADDING : Final = True
   @staticmethod
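
A note on the merge style: the old (lambda d: d.update(base_fxn_for_op) or d)({...}) pattern let the shared base entries win on any key collision, while {**base_fxn_for_op, **{...}} lets the backend-specific entries win. The tables above share no keys, so behavior is unchanged and the new form is simply clearer. A standalone sketch with placeholder keys, no tinygrad imports needed:

base = {"ADD": "generic add", "SUM": "generic sum"}
overrides = {"ADD": "torch add", "CONV": "torch conv2d"}

old_style = (lambda d: d.update(base) or d)(dict(overrides))  # base entries overwrite overrides
new_style = {**base, **overrides}                             # overrides overwrite base entries

assert old_style["ADD"] == "generic add"
assert new_style["ADD"] == "torch add"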

tinygrad/ops.py

@@ -3,7 +3,7 @@ import numpy as np
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict
 import functools, operator
-from tinygrad.helpers import prod, shape_to_axis
+from tinygrad.helpers import prod
 from tinygrad.shape import ShapeTracker
 from tinygrad.helpers import getenv
@@ -87,14 +87,6 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
     return ret
 def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
-base_fxn_for_op = {
-  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
-  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
-}
 class GlobalCounters:
   global_ops : ClassVar[int] = 0
   global_mem : ClassVar[int] = 0
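
The shape_to_axis import is dropped from ops.py above because its only remaining users, the ReduceOps entries, now live in ops_cpu.py. For reference, shape_to_axis maps an (input shape, reduced shape) pair to the axes being reduced; a rough sketch of the idea, not the exact tinygrad source:

def shape_to_axis_sketch(old_shape, new_shape):
  # the axes to reduce over are the positions where the target shape differs from the input shape
  return tuple(i for i, (a, b) in enumerate(zip(old_shape, new_shape)) if a != b)

assert shape_to_axis_sketch((2, 3, 4), (2, 1, 4)) == (1,)   # reduce over axis 1
assert shape_to_axis_sketch((2, 3, 4), (1, 1, 4)) == (0, 1) # reduce over axes 0 and 1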