remove match_type in ops_torch and ops_cpu (#2817)

* remove match_type in ops_torch and ops_cpu

input dtypes are already aligned and cast in mlops

* dict union only available in Python 3.9+

* fix that

* fix Sigmoid forward cast
Author: chenyu
Date: 2023-12-17 15:32:30 -05:00 (committed by GitHub)
Parent: 887f3d9933
Commit: 91adb119b8
3 changed files with 26 additions and 27 deletions
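
The rationale above (dtypes are aligned once in mlops) means the numpy backend no longer needs per-op casting. For reference, a minimal standalone sketch of the behavior being deleted; the PRIORITY table below is a hypothetical stand-in for tinygrad's dtypes.from_np(...).priority:

import numpy as np

# hypothetical stand-in for tinygrad's dtype priority table
PRIORITY = {np.dtype(np.bool_): 0, np.dtype(np.int32): 1, np.dtype(np.float32): 2}

def output_type(x, y): return x.dtype if PRIORITY[x.dtype] > PRIORITY[y.dtype] else y.dtype

def match_types(x, y):
  # upcast both operands to the shared output dtype, as the removed helper did
  up = output_type(x, y)
  return x.astype(up, copy=False), y.astype(up, copy=False)

a = np.arange(4, dtype=np.int32)
b = np.ones(4, dtype=np.float32)
x, y = match_types(a, b)
print(np.add(x, y).dtype)  # float32: both operands were upcast before the op

With alignment done once at the tensor level, the per-op numpy lambdas below can drop this entirely.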


@@ -10,9 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
 # TODO: this should be global infrastructure
 def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-def match_types(x, y):
-  up = output_type(x, y)
-  return x.astype(up, copy=False), y.astype(up, copy=False)
 def einsum_mulacc(einsum, get_strides, expand):
   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
@@ -24,20 +21,23 @@ def einsum_mulacc(einsum, get_strides, expand):
     return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
   return mulacc
+def as_strided(x, arg):
+  return np.ndarray(shape=arg[0], dtype=x.dtype, buffer=np.require(x, requirements='C'),
+                    offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1]))
 numpy_fxn_for_op: Dict[Op, Callable] = {
   BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
   UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
-  BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False),
-  BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
+  BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
+  BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
+  BinaryOps.XOR: np.bitwise_xor, UnaryOps.SQRT: np.sqrt,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
   ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])), # noqa: E501
-  MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
+  MovementOps.AS_STRIDED: as_strided, MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
+  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
   TernaryOps.WHERE: np.where,
 }
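
A quick sanity check of the new invariant: the element-wise lambdas now assume both inputs arrive with the same dtype (which mlops guarantees), so bare ufuncs like np.add preserve it, and only BinaryOps.DIV keeps an explicit astype because numpy's true division always promotes. A sketch, assuming int32 inputs:

import numpy as np

a = np.array([6, 7], dtype=np.int32)
b = np.array([2, 2], dtype=np.int32)

print(np.add(a, b).dtype)                                 # int32: a shared input dtype is preserved
print(np.divide(a, b).dtype)                              # float64: true division promotes integers
print(np.divide(a, b).astype(a.dtype, copy=False).dtype)  # int32: hence DIV's explicit cast back

Since output_type(x, y) reduces to the shared dtype when x.dtype == y.dtype, the DIV lambda's astype simply restores the input dtype.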