clean up ops, refactor pool backward. add stride test

George Hotz
2020-10-26 08:47:11 -07:00
parent 93dceb4bee
commit 2a55d7402b
2 changed files with 64 additions and 37 deletions

test/test_ops.py

@@ -40,6 +40,18 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
+  @unittest.skip("please write stride support")
+  def test_strided_conv2d(self):
+    bs = 4
+    cin = 3
+    H,W = 3,3
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
+      lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=2e-5, grad_atol=2e-6)
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(),
+      lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6)
   def test_maxpool2x2(self):
     helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
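Note on the skipped test: Tensor.conv2d has no stride argument yet, so the test documents the intended behavior (both an int stride and a (sy,sx) tuple) and stays skipped until stride support lands. For an unpadded convolution with stride (sy,sx), the output spatial size is ((H-kH)//sy+1, (W-kW)//sx+1). A minimal numpy sketch of what the test expects; conv2d_strided_ref is a made-up reference helper, not the repo's implementation:

import numpy as np

def conv2d_strided_ref(x, w, stride=(1, 1)):
  # hypothetical reference helper, not part of tinygrad
  # x: (bs, cin, H, W), w: (cout, cin, kH, kW); naive valid conv, no padding
  sy, sx = (stride, stride) if isinstance(stride, int) else stride
  bs, cin, H, W = x.shape
  cout, _, kH, kW = w.shape
  assert w.shape[1] == cin
  oy, ox = (H - kH)//sy + 1, (W - kW)//sx + 1
  out = np.zeros((bs, cout, oy, ox), dtype=x.dtype)
  for Y in range(oy):
    for X in range(ox):
      patch = x[:, :, Y*sy:Y*sy+kH, X*sx:X*sx+kW]   # (bs, cin, kH, kW) window at this output cell
      out[:, :, Y, X] = np.tensordot(patch, w, axes=([1,2,3],[1,2,3]))
  return out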

tinygrad/ops.py

@@ -1,17 +1,7 @@
 import numpy as np
 from tinygrad.tensor import Function, register
-class Reshape(Function):
-  @staticmethod
-  def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
-    return x.reshape(shape)
-  @staticmethod
-  def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reshape(in_shape), None
-register('reshape', Reshape)
 # ************* basic ops *************
 class Mul(Function):
   @staticmethod
@@ -35,19 +25,6 @@ class Add(Function):
     return grad_output, grad_output
 register('add', Add)
-class ReLU(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.maximum(input, 0)
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    grad_input = grad_output * (input >= 0)
-    return grad_input
-register('relu', ReLU)
 class Dot(Function):
   @staticmethod
   def forward(ctx, input, weight):
@@ -61,6 +38,7 @@ class Dot(Function):
     grad_weight = grad_output.T.dot(input).T
     return grad_input, grad_weight
 register('dot', Dot)
+register('matmul', Dot)
 class Sum(Function):
   @staticmethod
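register('matmul', Dot) only adds a second Tensor method name for the existing Dot op, so Tensor.matmul and Tensor.dot now run the same forward and backward. A quick usage check; it assumes the Tensor(numpy_array) constructor and .data attribute of this era:

import numpy as np
from tinygrad.tensor import Tensor

# assumes Tensor wraps a numpy array and exposes it as .data
x = Tensor(np.random.randn(1, 3).astype(np.float32))
w = Tensor(np.random.randn(3, 4).astype(np.float32))

# both names dispatch to the same Dot Function
assert np.allclose(x.dot(w).data, x.matmul(w).data)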
@@ -74,6 +52,34 @@ class Sum(Function):
     return grad_output * np.ones_like(input)
 register('sum', Sum)
+# ************* nn ops *************
+class ReLU(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.maximum(input, 0)
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    grad_input = grad_output * (input >= 0)
+    return grad_input
+register('relu', ReLU)
+class Reshape(Function):
+  @staticmethod
+  def forward(ctx, x, shape):
+    ctx.save_for_backward(x.shape)
+    return x.reshape(shape)
+  @staticmethod
+  def backward(ctx, grad_output):
+    in_shape, = ctx.saved_tensors
+    return grad_output.reshape(in_shape), None
+register('reshape', Reshape)
 class LogSoftmax(Function):
   @staticmethod
   def forward(ctx, input):
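ReLU and Reshape are unchanged here, only moved under the new nn ops section. As a reminder of what ReLU's backward computes, the (input >= 0) mask zeroes the gradient wherever the forward clamped the input; a standalone numpy illustration:

import numpy as np

inp = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
grad_output = np.ones_like(inp)
grad_input = grad_output * (inp >= 0)   # same mask the ReLU Function uses
# grad_input -> [0., 0., 1., 1., 1.]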
@@ -92,6 +98,8 @@ class LogSoftmax(Function):
 register('logsoftmax', LogSoftmax)
+# ************* conv ops *************
 class Conv2D(Function):
   @staticmethod
   def forward(ctx, x, w):
@@ -124,6 +132,9 @@ class Conv2D(Function):
     return dx, dw
 register('conv2d', Conv2D)
+# ************* pooling ops *************
 def stack_for_pool(x, py, px):
   my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
   stack = []
@@ -133,6 +144,16 @@ def stack_for_pool(x, py, px):
       stack.append(xup[:, :, Y::py, X::px][None])
   return np.concatenate(stack, axis=0)
+def unstack_for_pool(fxn, s, py, px):
+  my, mx = (s[2]//py)*py, (s[3]//px)*px
+  for Y in range(py):
+    for X in range(px):
+      ll = fxn(Y*px+X)
+      if X == 0 and Y == 0:
+        ret = np.zeros(s, dtype=ll.dtype)
+      ret[:, :, Y:my:py, X:mx:px] = ll
+  return ret
 class MaxPool2D(Function):
   @staticmethod
   def forward(ctx, x, kernel_size=(2, 2)):
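The new unstack_for_pool is the inverse of stack_for_pool: stack_for_pool crops the input to a multiple of the kernel size and stacks the py*px phase-shifted views along a new leading axis (so forward can reduce over axis 0), while unstack_for_pool scatters one value per phase back to its original positions. A quick shape check, assuming the helpers above are in scope:

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
stacked = stack_for_pool(x, 2, 2)
assert stacked.shape == (4, 1, 1, 2, 2)   # one slice per (Y, X) offset in each 2x2 window

# reducing over axis 0 is exactly 2x2 max pooling
assert (stacked.max(axis=0) == np.array([[[[5., 7.], [13., 15.]]]])).all()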
@@ -144,13 +165,9 @@ class MaxPool2D(Function):
   @staticmethod
   def backward(ctx, grad_output):
     idxs,s = ctx.saved_tensors
-    py, px = ctx.kernel_size
-    my, mx = (s[2]//py)*py, (s[3]//px)*px
-    ret = np.zeros(s, dtype=grad_output.dtype)
-    for Y in range(py):
-      for X in range(px):
-        ret[:, :, Y:my:py, X:mx:px] = grad_output * (idxs == (Y*px+X))
-    return ret
+    return unstack_for_pool(
+      lambda idx: grad_output * (idxs == idx),
+      s, *ctx.kernel_size)
 register('max_pool2d', MaxPool2D)
 class AvgPool2D(Function):
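The refactor does not change MaxPool2D's gradient: idxs stores, per output cell, the flat offset Y*px+X of the element that won the max in forward, and the lambda returns grad_output masked to that offset for each of the py*px phases, which is exactly what the old double loop wrote directly. A tiny worked case with one 2x2 window, using unstack_for_pool from above:

import numpy as np

grad_output = np.array([[[[1.0]]]])   # a single pooled output cell
idxs = np.array([[[[3]]]])            # forward recorded offset 3 (Y=1, X=1) as the max

grad_x = unstack_for_pool(lambda idx: grad_output * (idxs == idx), (1, 1, 2, 2), 2, 2)
# grad_x[0, 0] -> [[0., 0.], [0., 1.]]: the whole gradient lands on the max position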
@@ -164,10 +181,8 @@ class AvgPool2D(Function):
   def backward(ctx, grad_output):
     s, = ctx.saved_tensors
     py, px = ctx.kernel_size
-    my, mx = (s[2]//py)*py, (s[3]//px)*px
-    ret = np.zeros(s, dtype=grad_output.dtype)
-    for Y in range(py):
-      for X in range(px):
-        ret[:, :, Y:my:py, X:mx:px] = grad_output/py/px
-    return ret
+    return unstack_for_pool(
+      lambda idx: grad_output/py/px,
+      s, py, px)
 register('avg_pool2d', AvgPool2D)
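Likewise for AvgPool2D: every element of a py*px window receives grad_output/(py*px), and rows or columns cropped off because the input is not a multiple of the kernel size keep zero gradient. For example, a 3x3 input pooled with a 2x2 kernel only back-propagates into the top-left 2x2 block; again assuming unstack_for_pool is in scope:

import numpy as np

grad_output = np.array([[[[1.0]]]])   # the single pooled output cell
grad_x = unstack_for_pool(lambda idx: grad_output / 2 / 2, (1, 1, 3, 3), 2, 2)
# grad_x[0, 0] -> [[0.25, 0.25, 0.  ],
#                  [0.25, 0.25, 0.  ],
#                  [0.  , 0.  , 0.  ]]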