mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 13:58:00 -05:00)

add reciprocal

@@ -101,7 +101,10 @@ def get_getters(ewbufs, ret):
 def roundup(x, n=4): return (x+(n-1))//n * n
 class OpenCLBuffer(GPUBuffer):
   code_for_op = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
+    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
+    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "(1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
   }

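The code_for_op entries are C expression templates: when the backend renders an elementwise kernel, the placeholder A (and B for binary ops) is replaced by the operand's expression, so RECIPROCAL becomes either the portable (1.0/A) or the faster, lower-precision OpenCL builtin native_recip(A) when NATIVE_EXPLOG is set. A minimal sketch of that substitution, with render_unary as a hypothetical stand-in for the real kernel renderer:

# Hypothetical helper, not part of the repo: shows how an op template could be
# expanded into kernel source by substituting the operand expression for "A".
code_for_op = {
  "RECIPROCAL": "(1.0/A)",                 # portable form
  "RECIPROCAL_NATIVE": "native_recip(A)",  # OpenCL native builtin (faster, less precise)
}

def render_unary(template: str, operand: str) -> str:
  return template.replace("A", operand)

print(render_unary(code_for_op["RECIPROCAL"], "x[gid]"))         # (1.0/x[gid])
print(render_unary(code_for_op["RECIPROCAL_NATIVE"], "x[gid]"))  # native_recip(x[gid])
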
@@ -172,7 +172,7 @@ def compile(input, output_fn):
   try:
     from test.test_onnx import run_onnx_torch
     torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
-    print(tinygrad_out_np, torch_out)
+    print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2))
     np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
   except ModuleNotFoundError:
     pass

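Note that the quantity printed under the "mse" label is the sum (not the mean) of squared differences; it is just a quick scalar shown next to the assert_allclose check. A standalone sketch of the comparison, with made-up arrays standing in for the two outputs:

import numpy as np

# Made-up outputs standing in for tinygrad_out_np and torch_out.
tinygrad_out_np = np.array([0.999, 2.001, 3.000])
torch_out = np.array([1.0, 2.0, 3.0])

print("mse", np.sum((tinygrad_out_np - torch_out) ** 2))  # sum of squared errors
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
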
@@ -5,7 +5,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, Processing
 class CPUBuffer(np.ndarray):
   fxn_for_op = {
     UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
-    UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(),
+    UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
     BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
     BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float()
   }

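On the CPU backend each UnaryOps member maps to a plain numpy-level callable, so RECIPROCAL is just elementwise 1.0/x via broadcasting. A small sketch of the lookup applied by hand (inside tinygrad the buffer's unary_op does this dispatch):

import numpy as np

fxn = lambda x: 1.0 / x          # the new UnaryOps.RECIPROCAL entry
x = np.array([2.0, 4.0, 0.5])
print(fxn(x))                    # [0.5  0.25 2.  ]
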
@@ -66,7 +66,7 @@ class CLProgram:
 
 class GPUBuffer:
   code_for_op = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "(1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
   }

@@ -29,6 +29,15 @@ class Exp(Function):
   def backward(self, grad_output):
     return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 
+class Reciprocal(Function):
+  def forward(self, x):
+    ret = x.unary_op(UnaryOps.RECIPROCAL)
+    self.save_for_backward(ret)
+    return ret
+
+  def backward(self, grad_output):
+    return grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.saved_tensors[0]).binary_op(BinaryOps.MUL, self.saved_tensors[0])
+
 # TODO: add Neg? confirm the optimizer on Sub good enough
 
 # ************* reduce ops *************

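Reciprocal saves its own output r = 1/x, and backward applies d(1/x)/dx = -1/x^2 = -r^2, so the NEG/MUL/MUL chain above computes -grad_output * r * r. A quick numpy check of that rule against a central finite difference:

import numpy as np

x = np.array([0.5, 2.0, 3.0])
grad_output = np.ones_like(x)

r = 1.0 / x                              # what forward saves
grad_analytic = -grad_output * r * r     # NEG, then MUL r twice

eps = 1e-6
grad_numeric = ((1.0 / (x + eps)) - (1.0 / (x - eps))) / (2 * eps)
assert np.allclose(grad_analytic, grad_numeric, atol=1e-4)
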
@@ -89,9 +98,6 @@ class Mul(Function):
     grad_y = self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
     return grad_x, grad_y
 
-# TODO: add Div? is the optimizer on Pow good enough?
-# nope, we def need div, can't optimize that
-
 class Pow(Function):
   def forward(self, x, y):
     ret = x.binary_op(BinaryOps.POW, y)

@@ -10,7 +10,7 @@ from tinygrad.shapetracker import ShapeTracker
 sys.setrecursionlimit(10000)
 
 # these are the llops your accelerator must implement, along with toCpu
-UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
+UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
 MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "EXPAND", "FLIP", "STRIDED", "PAD", "SHRINK"])

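Since the functional Enum constructor numbers members in order starting at 1, appending RECIPROCAL at the end leaves the existing op values untouched; every backend table (fxn_for_op / code_for_op) then needs a matching entry, which is what the other hunks in this commit add. A quick check of the numbering:

from enum import Enum

UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
assert UnaryOps.SIGN.value == 6 and UnaryOps.RECIPROCAL.value == 7
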
@@ -241,7 +241,7 @@ class Tensor:
 
   # ***** activation functions (unary) *****
 
-  def sigmoid(self): return (1.0 + (-self).exp()) ** -1.0
+  def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()
   # TODO: implement generic constant folding
   def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
   def swish(self): return self * self.sigmoid()

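Both forms compute sigmoid(x) = 1 / (1 + exp(-x)); the change only swaps the ** -1.0 (which routes through Pow) for the new reciprocal op. A numpy sketch confirming the two formulations agree numerically:

import numpy as np

x = np.linspace(-4, 4, 9)
t = 1.0 + np.exp(-x)
old_way = t ** -1.0      # previous formulation
new_way = 1.0 / t        # reciprocal-based formulation
assert np.allclose(old_way, new_way)
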
@@ -268,9 +268,7 @@ class Tensor:
   def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
   def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
   def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
-
-  # TODO: should be broadcasted binary op
-  def div(self, y): return self * (y ** -1.0)
+  def div(self, y): return self * y.reciprocal()
   __truediv__ = div
 
   # ***** functional nn ops *****

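div keeps the same result, x / y == x * (1/y), but the backward pass now flows through the new Reciprocal function (plus Mul) rather than Pow. A small numpy sketch of the identity and of the resulting gradient with respect to y:

import numpy as np

x = np.array([1.0, -2.0, 3.0])
y = np.array([4.0, 0.5, -1.5])
assert np.allclose(x * (y ** -1.0), x * (1.0 / y))   # old vs new formulation
assert np.allclose(x / y, x * (1.0 / y))

# d(x/y)/dy = x * d(1/y)/dy = -x * (1/y)**2
grad_y = -x * (1.0 / y) ** 2
print(grad_y)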