diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 79b693b1f3..b0cf7a58be 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -2,14 +2,14 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable
 import functools
 from enum import Enum, auto
-from tinygrad.helpers import prod, DType, least_upper_dtype, dedup
+from tinygrad.helpers import prod, DType, dedup
 from tinygrad.shape.symbolic import Variable
 from dataclasses import dataclass

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
-# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
+# NOTE: rdna3 only has RECIP and not DIV. DIV is on the chopping block
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); XOR = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
@@ -80,9 +80,9 @@ InterpretedFlopCounter: Dict[Op, Callable] = {
   BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501
   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
   **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
-  **{op:lambda self,y: FlopCounter(self.shape, least_upper_dtype(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+  **{op:lambda self,y: FlopCounter(self.shape, self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
   **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
-  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, least_upper_dtype(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
+  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

 @functools.lru_cache(None)
 def get_lazyop_info(ast:LazyOp) -> FlopCounter:
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 8a0a20cd9d..c4ec44bae1 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -1,6 +1,6 @@
 import numpy as np
 from typing import Callable, Dict, Tuple
-from tinygrad.helpers import dtypes, flat_mv
+from tinygrad.helpers import flat_mv
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
 from tinygrad.device import Interpreted, Allocator

@@ -8,9 +8,6 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
   return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)

-# TODO: this should be global infrastructure
-def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-
 def einsum_mulacc(einsum, get_strides, expand):
   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
   def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
@@ -33,7 +30,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
   UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: x
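
For context on the behavioral change above: `InterpretedFlopCounter` previously promoted the output dtype of `BinaryOps` with `least_upper_dtype(self.dtype, y.dtype)` (and of `TernaryOps.WHERE` with `least_upper_dtype(y.dtype, z.dtype)`); after this diff it simply propagates the first value operand's dtype. The numpy backend's `output_type` helper, which implemented the same priority-based promotion via `dtypes.from_np(...).priority`, is deleted for the same reason. Below is a minimal standalone sketch of the two rules, assuming operand dtypes already match by the time the counter runs; the `DT` class and its priority numbers are hypothetical stand-ins, not tinygrad's actual `DType` table:

from dataclasses import dataclass

@dataclass(frozen=True)
class DT:
  name: str
  priority: int  # higher priority wins promotion, as in tinygrad's DType

half, float32 = DT("half", 9), DT("float", 10)  # hypothetical priorities

def promoted_dtype(a: DT, b: DT) -> DT:
  # old rule: least upper bound by priority, like least_upper_dtype/output_type
  return a if a.priority >= b.priority else b

def first_operand_dtype(a: DT, b: DT) -> DT:
  # new rule: the flop counter just takes the left operand's dtype
  return a

assert promoted_dtype(half, float32) is float32       # old: promotes to float
assert first_operand_dtype(half, float32) is half     # new: keeps the left dtype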