remove output_type in ops_cpu and ops_torch (#2892)

now the input types are matched and checked in lazy, we can remove these output_type.
also remove the usage of least_upper_dtype in ops.py since we can just use the input type
This commit is contained in:
chenyu
2023-12-21 02:11:27 -05:00
committed by GitHub
parent 2d2c4980fe
commit 1500aca43d
3 changed files with 9 additions and 14 deletions

View File

@@ -1,6 +1,6 @@
import numpy as np
from typing import Callable, Dict, Tuple
from tinygrad.helpers import dtypes, flat_mv
from tinygrad.helpers import flat_mv
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator
@@ -8,9 +8,6 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
# TODO: this should be global infrastructure
def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
def einsum_mulacc(einsum, get_strides, expand):
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
@@ -33,7 +30,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False),
BinaryOps.XOR: np.bitwise_xor,
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.WHERE: np.where,