mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
write sqrt and div using pow
This commit is contained in:
@@ -31,6 +31,20 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7):
|
||||
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
|
||||
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_add(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
|
||||
def test_sub(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
|
||||
def test_mul(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul)
|
||||
def test_div(self):
|
||||
# TODO: why does this need more tolerance?
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, atol=5e-5)
|
||||
def test_pow(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow)
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt)
|
||||
|
||||
def test_conv2d(self):
|
||||
for bs in [1,8]:
|
||||
for cin in [1,3]:
|
||||
|
||||
@@ -3,18 +3,6 @@ from .tensor import Function, register
|
||||
|
||||
# ************* basic ops *************
|
||||
|
||||
class Mul(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return x*y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
return y*grad_output, x*grad_output
|
||||
register('mul', Mul)
|
||||
|
||||
class Add(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
@@ -36,30 +24,30 @@ class Sub(Function):
|
||||
return grad_output, -grad_output
|
||||
register('sub', Sub)
|
||||
|
||||
class Div(Function):
|
||||
class Mul(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return x/y
|
||||
return x*y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# this right?
|
||||
x,y = ctx.saved_tensors
|
||||
return y/grad_output, x/grad_output
|
||||
register('div', Div)
|
||||
return y*grad_output, x*grad_output
|
||||
register('mul', Mul)
|
||||
|
||||
class Sqrt(Function):
|
||||
class Pow(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return np.sqrt(x)
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return x ** y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise Exception("write this")
|
||||
register('sqrt', Sqrt)
|
||||
|
||||
x,y = ctx.saved_tensors
|
||||
return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
|
||||
register('pow', Pow)
|
||||
|
||||
class Dot(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
|
||||
@@ -10,13 +10,12 @@ class Tensor:
|
||||
#print(type(data), data)
|
||||
if type(data) == list:
|
||||
data = np.array(data, dtype=np.float32)
|
||||
if type(data) != np.ndarray:
|
||||
elif type(data) != np.ndarray:
|
||||
print("error constructing tensor with %r" % data)
|
||||
assert(False)
|
||||
if data.dtype == np.float64:
|
||||
#print("are you sure you want float64 in %r?" % data)
|
||||
pass
|
||||
self.data = data
|
||||
|
||||
# only float32
|
||||
self.data = data.astype(np.float32)
|
||||
self.grad = None
|
||||
|
||||
# internal variables used for autograd graph construction
|
||||
@@ -67,10 +66,20 @@ class Tensor:
|
||||
t.grad = g
|
||||
t.backward(False)
|
||||
|
||||
# ***** non first class ops *****
|
||||
|
||||
def mean(self):
|
||||
div = Tensor(np.array([1/self.data.size], dtype=self.data.dtype))
|
||||
return self.sum().mul(div)
|
||||
|
||||
def sqrt(self):
|
||||
root = Tensor(np.zeros(self.shape)+0.5)
|
||||
return self.pow(root)
|
||||
|
||||
def div(self, y):
|
||||
root = Tensor(np.zeros(self.shape)-1)
|
||||
return self.mul(y.pow(root))
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
def __init__(self, *tensors):
|
||||
|
||||
Reference in New Issue
Block a user