mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove output_type in ops_cpu and ops_torch (#2892)
Now that the input types are matched and checked in lazy, we can remove these output_type helpers. Also remove the usage of least_upper_dtype in ops.py, since we can just use the input type.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from typing import Callable, Dict, Tuple
|
||||
from tinygrad.helpers import dtypes, flat_mv
|
||||
from tinygrad.helpers import flat_mv
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
|
||||
@@ -8,9 +8,6 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
|
||||
assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
|
||||
return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
|
||||
|
||||
# TODO: this should be global infrastructure
|
||||
def output_type(x, y):
  """Return the numpy dtype of whichever input has the higher tinygrad dtype priority.

  Ties go to y's dtype (only a strictly greater priority on x wins).
  """
  # NOTE(review): assumes x and y are numpy arrays whose dtypes map via dtypes.from_np — confirm at call sites.
  if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority:
    return x.dtype
  return y.dtype
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
def einscripts(x):
  # Translate a sequence of axis indices into einsum subscript letters (0 -> 'a', 1 -> 'b', ...).
  alphabet = "abcdefghijklmnopqrstuvwxyz"
  return "".join(alphabet[axis] for axis in x)
|
||||
def axes_slice(strides):
  # Given per-axis strides, return (indices of non-zero-stride axes,
  # a slicer that keeps those axes whole and collapses zero-stride axes to index 0).
  kept = []
  slicer = []
  for idx, stride in enumerate(strides):
    if stride != 0:
      kept.append(idx)
      slicer.append(slice(None))
    else:
      slicer.append(0)
  return tuple(kept), tuple(slicer)
|
||||
@@ -33,7 +30,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
|
||||
BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
|
||||
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
|
||||
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False),
|
||||
BinaryOps.XOR: np.bitwise_xor,
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
TernaryOps.WHERE: np.where,
|
||||
|
||||
Reference in New Issue
Block a user