diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py
index db99418ead..a643a58a15 100644
--- a/accel/opencl/ops_opencl.py
+++ b/accel/opencl/ops_opencl.py
@@ -101,7 +101,10 @@ def get_getters(ewbufs, ret):
 def roundup(x, n=4): return (x+(n-1))//n * n
 class OpenCLBuffer(GPUBuffer):
   code_for_op = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
+    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
+    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "(1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
   }
diff --git a/openpilot/compile.py b/openpilot/compile.py
index 132219217c..9b8fca77de 100644
--- a/openpilot/compile.py
+++ b/openpilot/compile.py
@@ -172,7 +172,7 @@ def compile(input, output_fn):
   try:
     from test.test_onnx import run_onnx_torch
     torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
-    print(tinygrad_out_np, torch_out)
+    print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2))
     np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
   except ModuleNotFoundError:
     pass
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index f5695aa28b..cdbde32fdd 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -5,7 +5,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, Processing
 class CPUBuffer(np.ndarray):
   fxn_for_op = {
     UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
-    UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(),
+    UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
     BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
     BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float()
   }
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 92983654e9..23df087dff 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -66,7 +66,7 @@ class CLProgram:
 
 class GPUBuffer:
   code_for_op = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "(1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
   }
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 6765663f92..7301b66e3f 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -29,6 +29,15 @@ class Exp(Function):
   def backward(self, grad_output):
     return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 
+class Reciprocal(Function):
+  def forward(self, x):
+    ret = x.unary_op(UnaryOps.RECIPROCAL)
+    self.save_for_backward(ret)
+    return ret
+
+  def backward(self, grad_output):
+    return grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.saved_tensors[0]).binary_op(BinaryOps.MUL, self.saved_tensors[0])
+
 # TODO: add Neg? confirm the optimizer on Sub good enough
 
 # ************* reduce ops *************
@@ -89,9 +98,6 @@ class Mul(Function):
     grad_y = self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
     return grad_x, grad_y
 
-# TODO: add Div? is the optimizer on Pow good enough?
-# nope, we def need div, can't optimize that
-
 class Pow(Function):
   def forward(self, x, y):
     ret = x.binary_op(BinaryOps.POW, y)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 50a8142fb4..733533a938 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -10,7 +10,7 @@ from tinygrad.shapetracker import ShapeTracker
 sys.setrecursionlimit(10000)
 
 # these are the llops your accelerator must implement, along with toCpu
-UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
+UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
 MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "EXPAND", "FLIP", "STRIDED", "PAD", "SHRINK"])
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0292a664db..1776ef58a4 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -241,7 +241,7 @@ class Tensor:
 
   # ***** activation functions (unary) *****
 
-  def sigmoid(self): return (1.0 + (-self).exp()) ** -1.0
+  def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()  # TODO: implement generic constant folding
   def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
   def swish(self): return self * self.sigmoid()
 
@@ -268,9 +268,7 @@ class Tensor:
   def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
   def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
   def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
-
-  # TODO: should be broadcasted binary op
-  def div(self, y): return self * (y ** -1.0)
+  def div(self, y): return self * y.reciprocal()
   __truediv__ = div
 
   # ***** functional nn ops *****
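
A quick sanity check of the new op, outside the diff. This is a minimal sketch, not part of the change: it assumes the lowercase Tensor.reciprocal() binding that tinygrad auto-registers from the Reciprocal mlop above, and that the Tensor constructor accepts requires_grad.

# sketch: check the RECIPROCAL llop forward and the Reciprocal backward (d(1/x)/dx = -1/x**2)
import numpy as np
from tinygrad.tensor import Tensor

x_np = np.array([0.5, 2.0, -3.0], dtype=np.float32)
x = Tensor(x_np, requires_grad=True)   # assumes requires_grad kwarg on the constructor
x.reciprocal().sum().backward()

# Reciprocal.backward returns -grad_output * ret * ret, i.e. -1/x**2 for a sum loss
np.testing.assert_allclose(x.grad.numpy(), -1.0/x_np**2, atol=1e-5, rtol=1e-3)

# sigmoid and div now route through RECIPROCAL instead of POW
np.testing.assert_allclose(Tensor(x_np).sigmoid().numpy(), 1.0/(1.0 + np.exp(-x_np)), atol=1e-5, rtol=1e-3)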