mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove match_type in ops_torch and ops_cpu (#2817)
* remove match_type in ops_torch and ops_cpu input dtypes are aligned and casted in mlops * dict union only after python3.9 * fix that * fix Sigmoid forward cast
This commit is contained in:
@@ -10,9 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
|
||||
|
||||
# TODO: this should be global infrastructure
|
||||
def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
|
||||
def match_types(x, y):
|
||||
up = output_type(x, y)
|
||||
return x.astype(up, copy=False), y.astype(up, copy=False)
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
|
||||
@@ -24,20 +21,23 @@ def einsum_mulacc(einsum, get_strides, expand):
|
||||
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
|
||||
return mulacc
|
||||
|
||||
def as_strided(x, arg):
|
||||
return np.ndarray(shape=arg[0], dtype=x.dtype, buffer=np.require(x, requirements='C'),
|
||||
offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1]))
|
||||
|
||||
numpy_fxn_for_op: Dict[Op, Callable] = {
|
||||
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
|
||||
UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
|
||||
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
|
||||
UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
|
||||
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
|
||||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False),
|
||||
BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
|
||||
BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
|
||||
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
|
||||
BinaryOps.XOR: np.bitwise_xor, UnaryOps.SQRT: np.sqrt,
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
|
||||
ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
|
||||
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])), # noqa: E501
|
||||
MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
MovementOps.AS_STRIDED: as_strided, MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
TernaryOps.WHERE: np.where,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user