diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 63e8cf7e9d..9635ff4b12 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -1,16 +1,26 @@
 import numpy as np
-from typing import ClassVar
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
+import operator
+from typing import ClassVar, Callable, Dict
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, ProcessingOps, GenericExecAST, Op
+from tinygrad.helpers import shape_to_axis
 
-specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
+base_fxn_for_op : Dict[Op, Callable] = {
+  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
+  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
+  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
+  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
+  MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
+}
+
+numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
   MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
   MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
   MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
-})
+}}
 
 class CPUBuffer(GenericExecAST):
-  fxn_for_op : ClassVar = specialized_fxn_for_op
+  fxn_for_op : ClassVar = numpy_fxn_for_op
 
   @staticmethod
   def fromCPU(x): return CPUBuffer(x)
diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 2e6f79658a..8e8c774f62 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -1,18 +1,19 @@
 import torch
-from typing import ClassVar, Final
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
+from typing import ClassVar, Final, Dict, Callable
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, Op
 from tinygrad.helpers import getenv
+from tinygrad.llops.ops_cpu import base_fxn_for_op
 
-specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
+torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]),
   ProcessingOps.CONV: lambda x,w,C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
-})
+}}
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 
 class TorchBuffer(GenericExecAST):
-  fxn_for_op : ClassVar = specialized_fxn_for_op
+  fxn_for_op : ClassVar = torch_fxn_for_op
   SUPPORTS_SIMPLE_PADDING : Final = True
 
   @staticmethod
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ad97eda799..1d03859c3e 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -3,7 +3,7 @@ import numpy as np
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict
 import functools, operator
-from tinygrad.helpers import prod, shape_to_axis
+from tinygrad.helpers import prod
 from tinygrad.shape import ShapeTracker
 from tinygrad.helpers import getenv
 
@@ -87,14 +87,6 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
     return ret
 def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
 
-base_fxn_for_op = {
-  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
-  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
-}
-
 class GlobalCounters:
   global_ops : ClassVar[int] = 0
   global_mem : ClassVar[int] = 0