onnx ops cleanup (#2413)

* onnx ops cleanup

* revert those
Author: George Hotz (committed via GitHub)
Date: 2023-11-23 18:39:49 -08:00
Parent: 8f89e21fca
Commit: 12023b6824
3 changed files with 22 additions and 41 deletions

extra/onnx.py

@@ -147,7 +147,9 @@ def get_run_onnx(onnx_model: ModelProto):
     # NOTE some ops live here because they require access to some local variables
     # have to use n.output for cases when num_outputs is absent
-    if n.op_type == "Split":
+    if n.op_type in onnx_ops.tensor_methods:
+      ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
+    elif n.op_type == "Split":
       axis = opt.get("axis", 0)
       split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
       if split is None:
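
The hunk above replaces a stack of one-line wrapper functions with name-based dispatch: any op whose name appears in onnx_ops.tensor_methods is routed straight to the lower-cased Tensor method of the same name. A minimal self-contained sketch of that pattern (MiniTensor, TENSOR_METHODS, and run_op are illustrative stand-ins, not tinygrad API):

# Name-based op dispatch: look the method up on the class and call it with
# the inputs, mirroring getattr(Tensor, n.op_type.lower())(*inp, **opt).
class MiniTensor:
  def __init__(self, v): self.v = v
  def relu(self): return MiniTensor(max(self.v, 0.0))
  def neg(self): return MiniTensor(-self.v)

TENSOR_METHODS = {"Relu", "Neg"}

def run_op(op_type, *inputs, **attrs):
  if op_type in TENSOR_METHODS:
    return getattr(MiniTensor, op_type.lower())(*inputs, **attrs)
  raise NotImplementedError(op_type)

print(run_op("Relu", MiniTensor(-3.0)).v)  # 0.0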

extra/onnx_ops.py

@@ -10,32 +10,16 @@ import functools
 from typing import Union, Tuple, Optional, List, Any
 import math
+tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "Tanh", "MatMul",
+                  "Floor", "Ceil", "Tanh", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Softsign", "Asinh", "Acosh", "Atanh"}
 # **************** Free Ops ****************
 def Identity(input: Tensor): return input
-def Neg(input: Tensor): return -input
 def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
 def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
-def Mul(input: Tensor, other: Tensor): return (input * other) if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input * other).cast(input.dtype)
-# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
-def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor()
-def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype)
-def Reciprocal(input: Tensor): return input.reciprocal()
-def Sqrt(input: Tensor): return input.sqrt()
-def Sign(input: Tensor): return input.sign()
-def Abs(input: Tensor): return input.abs()
-def Exp(input: Tensor): return input.exp()
-def Log(input: Tensor): return input.log()
-def Mish(input: Tensor): return input.mish()
-def Sin(x: Tensor): return x.sin()
-def Cos(x: Tensor): return x.cos()
-def Tan(x: Tensor): return x.tan()
-def Relu(input: Tensor): return input.relu()
-def Sigmoid(input: Tensor): return input.sigmoid()
-def Tanh(input: Tensor): return input.tanh()
-def MatMul(input: Tensor, other: Tensor): return input.matmul(other)
-def Floor(x:Tensor): return x.floor()
-def Ceil(x:Tensor): return x.ceil()
+def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor() # TODO: this has dtype issues
+def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype) # TODO: this has dtype issues
 def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
 def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
 def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
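
The TODO comments on Div and Pow flag a real wrinkle: for non-float dtypes the arithmetic runs in float and is cast or floored back, and .floor() rounds toward negative infinity, which changes results for negative integer operands. A NumPy sketch of the cast-back pattern (div_like_onnx is a hypothetical stand-in, with NumPy in place of tinygrad Tensors):

import numpy as np

# Compute in float, then bring the result back to the input dtype, as the
# Div wrapper does with input.div(other).floor().
def div_like_onnx(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  if np.issubdtype(a.dtype, np.floating): return a / b
  return np.floor(a / b).astype(a.dtype)

a = np.array([7, -7], dtype=np.int32)
b = np.array([2, 2], dtype=np.int32)
print(div_like_onnx(a, b))  # [ 3 -4]: floor rounding, input dtype preserved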
@@ -45,7 +29,6 @@ def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
 def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
 def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
 def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
-def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
 def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
 # **************** Simple Ops ****************
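
Max, Min, Sum, and Mean accept a variable number of inputs, so each folds a binary Tensor op over the argument tuple with functools.reduce. The same shape in plain Python:

import functools

# Fold a two-argument op across a variadic argument list, as Max/Min/Sum do.
def Sum(*data_0): return functools.reduce(lambda a, b: a + b, data_0)
def Mean(*data_0): return Sum(*data_0) / len(data_0)

print(Sum(1, 2, 3), Mean(1, 2, 3))  # 6 2.0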
@@ -58,17 +41,10 @@ def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=
   elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
   elif value_string or value_strings: raise NotImplementedError(f'value_string or value_strings not implemented for Constant op')
-def Softsign(input: Tensor): return input / (1+input.abs())
-def Cosh(x): return (math.e ** x + math.e ** -x) / 2
-def Sinh(x): return (math.e ** x - math.e ** -x) / 2
-def Tanh(x): return x.tanh()
 def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
-def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
 def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
-def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu()
+def Celu(x:Tensor, alpha=1.0): return x.celu(alpha)
 def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
-def Softplus(X: Tensor): return X.softplus()
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
   return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
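
Celu now delegates to Tensor.celu instead of the handwritten relu expression. Taking the ONNX definition CELU(x) = max(0, x) + min(0, alpha*(exp(x/alpha) - 1)) as reference, a quick NumPy check (celu_old and celu_ref are stand-ins for the two formulations) suggests the removed expression only agrees with it at alpha == 1, which may be part of why it was replaced:

import numpy as np

def celu_old(x, alpha=1.0):
  # the removed formulation: X.relu() - (-alpha*(X/alpha).exp()+1).relu()
  relu = lambda v: np.maximum(v, 0)
  return relu(x) - relu(-alpha*np.exp(x/alpha) + 1)

def celu_ref(x, alpha=1.0):
  # canonical CELU from the ONNX spec
  return np.maximum(x, 0) + np.minimum(0, alpha*(np.exp(x/alpha) - 1))

x = np.linspace(-2, 2, 5)
print(np.allclose(celu_old(x, 1.0), celu_ref(x, 1.0)))  # True
print(np.allclose(celu_old(x, 2.0), celu_ref(x, 2.0)))  # False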
@@ -104,15 +80,12 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.sha
 def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
 def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
 def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
-def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
-def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(dtypes.bool)
-def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
-def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
+def And(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
+def Or(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool)
+def Xor(x:Tensor, y:Tensor): return (x==y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
+def Not(x:Tensor): return (x==1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
 def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
-def Asinh(x): return Tensor.log(x + Tensor.sqrt(x * x + 1))
-def Acosh(x): return Tensor.log(x + Tensor.sqrt(x * x - 1))
-def Atanh(x): return 0.5 * Tensor.log((1 + x)/(1 - x))
 def Acos(x: Tensor):
   negate = (x < 0)
   x = x.abs()
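
The rewritten And/Or/Xor/Not all use the same trick: one elementwise comparison plus a select. For 0/1 inputs, when x == y both And and Or can simply return x (the inputs agree) while Xor returns 0; when they differ, And is 0 and Or and Xor are 1. Demonstrating Xor with NumPy's where standing in for Tensor.where:

import numpy as np

# Xor: where x == y the answer is 0, everywhere else it is 1.
def Xor(x, y): return np.where(x == y, np.zeros_like(x), np.ones_like(x)).astype(bool)

x = np.array([0., 1., 0., 1.])
y = np.array([0., 0., 1., 1.])
print(Xor(x, y))  # [False  True  True False]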
@@ -180,7 +153,6 @@ def Expand(input: Tensor, shape):
   shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
   return input.reshape(x_shape).expand(shape_ret)
-# **************** Complex Ops ****************
 def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
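
Gemm's body is cut off in this view; per the ONNX spec it computes Y = alpha * A' @ B' + beta * C, where A' and B' are optionally transposed. A NumPy sketch of that contract (not the literal tinygrad body):

import numpy as np

# ONNX Gemm contract: optional transposes, scaled matmul, optional bias.
def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
  A = A.T if transA else A
  B = B.T if transB else B
  Y = alpha * (A @ B)
  return Y if C is None else Y + beta * C

A = np.arange(6, dtype=np.float32).reshape(2, 3)
B = np.arange(6, dtype=np.float32).reshape(3, 2)
print(gemm(A, B, alpha=2.0))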

tinygrad/tensor.py

@@ -603,7 +603,7 @@ class Tensor:
   # ***** mlops (unary) *****
-  def __neg__(self): return mlops.Neg.apply(self)
+  def neg(self): return mlops.Neg.apply(self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
   def log(self): return mlops.Log.apply(self)
@@ -639,6 +639,11 @@ class Tensor:
   def relu6(self): return self.relu() - (self-6).relu()
   def hardswish(self): return self * (self+3).relu6() * (1/6)
   def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
+  def sinh(self): return (self.exp() - self.neg().exp()) / 2
+  def cosh(self): return (self.exp() + self.neg().exp()) / 2
+  def atanh(self): return ((1 + self)/(1 - self)).log() / 2
+  def asinh(self): return (self + (self.square() + 1).sqrt()).log()
+  def acosh(self): return (self + (self.square() - 1).sqrt()).log()
   def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
   def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
   def quick_gelu(self): return self * (self * 1.702).sigmoid()
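
The five new methods are the textbook exp/log identities: sinh x = (e^x - e^-x)/2, cosh x = (e^x + e^-x)/2, atanh x = log((1+x)/(1-x))/2, asinh x = log(x + sqrt(x^2+1)), acosh x = log(x + sqrt(x^2-1)). A quick scalar check against Python's math module:

import math

x = 0.5
assert math.isclose((math.exp(x) - math.exp(-x)) / 2, math.sinh(x))
assert math.isclose((math.exp(x) + math.exp(-x)) / 2, math.cosh(x))
assert math.isclose(math.log((1 + x) / (1 - x)) / 2, math.atanh(x))
assert math.isclose(math.log(x + math.sqrt(x*x + 1)), math.asinh(x))
x = 1.5  # acosh is defined for x >= 1
assert math.isclose(math.log(x + math.sqrt(x*x - 1)), math.acosh(x))
print("identities hold")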
@@ -716,7 +721,9 @@ class Tensor:
     x,z = x_._broadcasted(other)
     return mlops.Where.apply(x, *y._broadcasted(z))
-  # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
+  # ***** op wrappers (wasted lines to make the typechecker happy) *****
+  def __neg__(self) -> Tensor: return self.neg()
   def __add__(self, x) -> Tensor: return self.add(x)
   def __sub__(self, x) -> Tensor: return self.sub(x)
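
Moving __neg__ down here keeps every dunder in the one block whose only job is to give the typechecker a concrete Tensor return type; each wrapper just forwards to the named method. The shape of the pattern, with Vec as an illustrative stand-in:

from __future__ import annotations

# Dunders forward to named methods so the annotated -> Vec return type is
# what the typechecker sees for operator expressions.
class Vec:
  def __init__(self, v: float): self.v = v
  def neg(self) -> Vec: return Vec(-self.v)
  def add(self, x: Vec) -> Vec: return Vec(self.v + x.v)
  # ***** op wrappers *****
  def __neg__(self) -> Vec: return self.neg()
  def __add__(self, x: Vec) -> Vec: return self.add(x)

print((-Vec(2.0) + Vec(5.0)).v)  # 3.0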