mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
clean up ops, refactor pool backward. add stride test
@@ -40,6 +40,18 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
 
+  @unittest.skip("please write stride support")
+  def test_strided_conv2d(self):
+    bs = 4
+    cin = 3
+    H,W = 3,3
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
+      lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=2e-5, grad_atol=2e-6)
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(),
+      lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6)
+
   def test_maxpool2x2(self):
     helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
 
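A note on what the skipped test is asking for: with no padding, a stride-s convolution is just the stride-1 result sampled every s rows/columns, so one simple (if wasteful) way to back Tensor.conv2d(x,w,stride=...) would be to subsample a full conv. A hedged torch-only sketch of that equivalence, using the same shapes as the test:

import torch
x = torch.randn(4, 3, 11, 28)
w = torch.randn(4, 3, 3, 3)
full = torch.nn.functional.conv2d(x, w)                    # stride 1, no padding
strided = torch.nn.functional.conv2d(x, w, stride=(2, 1))
# subsampling the stride-1 output reproduces the strided output (loose float32 tolerance)
assert torch.allclose(full[:, :, ::2, ::1], strided, atol=1e-4)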
@@ -1,17 +1,7 @@
 import numpy as np
 from tinygrad.tensor import Function, register
 
-class Reshape(Function):
-  @staticmethod
-  def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
-    return x.reshape(shape)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reshape(in_shape), None
-register('reshape', Reshape)
 # ************* basic ops *************
 
 class Mul(Function):
   @staticmethod
@@ -35,19 +25,6 @@ class Add(Function):
     return grad_output, grad_output
 register('add', Add)
 
-class ReLU(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.maximum(input, 0)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    grad_input = grad_output * (input >= 0)
-    return grad_input
-register('relu', ReLU)
-
 class Dot(Function):
   @staticmethod
   def forward(ctx, input, weight):
@@ -61,6 +38,7 @@ class Dot(Function):
     grad_weight = grad_output.T.dot(input).T
     return grad_input, grad_weight
 register('dot', Dot)
+register('matmul', Dot)
 
 class Sum(Function):
   @staticmethod
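register('matmul', Dot) only adds a second name for the same Function, so Tensor.matmul and Tensor.dot share this forward/backward. As a side note on the backward shown above, grad_output.T.dot(input).T is the usual weight gradient input.T @ grad_output with the operands swapped; a small NumPy check with arbitrarily chosen shapes:

import numpy as np
inp = np.random.randn(8, 5).astype(np.float32)           # input, shape (batch, in_features)
grad_output = np.random.randn(8, 3).astype(np.float32)   # upstream grad, shape (batch, out_features)
# the two expressions are the same matrix, up to float32 rounding
assert np.allclose(grad_output.T.dot(inp).T, inp.T.dot(grad_output), atol=1e-5)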
@@ -74,6 +52,34 @@ class Sum(Function):
     return grad_output * np.ones_like(input)
 register('sum', Sum)
 
+
+# ************* nn ops *************
+
+class ReLU(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.maximum(input, 0)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    grad_input = grad_output * (input >= 0)
+    return grad_input
+register('relu', ReLU)
+
+class Reshape(Function):
+  @staticmethod
+  def forward(ctx, x, shape):
+    ctx.save_for_backward(x.shape)
+    return x.reshape(shape)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    in_shape, = ctx.saved_tensors
+    return grad_output.reshape(in_shape), None
+register('reshape', Reshape)
+
 class LogSoftmax(Function):
   @staticmethod
   def forward(ctx, input):
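register() attaches each Function as a Tensor method, so moving ReLU and Reshape under the nn ops banner is purely organizational; dispatch is unchanged. A minimal usage sketch, assuming the NumPy-backed Tensor API at this point in the tree (scalar output before backward):

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(3, 3).astype(np.float32))
out = x.relu().sum()   # dispatches to the ReLU and Sum Functions registered above
out.backward()
print(x.grad)          # 1 where x >= 0 and 0 elsewhere, per ReLU.backward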
@@ -92,6 +98,8 @@ class LogSoftmax(Function):
 register('logsoftmax', LogSoftmax)
 
+
+# ************* conv ops *************
 
 class Conv2D(Function):
   @staticmethod
   def forward(ctx, x, w):
@@ -124,6 +132,9 @@ class Conv2D(Function):
     return dx, dw
 register('conv2d', Conv2D)
 
+
+# ************* pooling ops *************
+
 def stack_for_pool(x, py, px):
   my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
   stack = []
@@ -133,6 +144,16 @@ def stack_for_pool(x, py, px):
       stack.append(xup[:, :, Y::py, X::px][None])
   return np.concatenate(stack, axis=0)
 
+def unstack_for_pool(fxn, s, py, px):
+  my, mx = (s[2]//py)*py, (s[3]//px)*px
+  for Y in range(py):
+    for X in range(px):
+      ll = fxn(Y*px+X)
+      if X == 0 and Y == 0:
+        ret = np.zeros(s, dtype=ll.dtype)
+      ret[:, :, Y:my:py, X:mx:px] = ll
+  return ret
+
 class MaxPool2D(Function):
   @staticmethod
   def forward(ctx, x, kernel_size=(2, 2)):
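unstack_for_pool is the inverse scatter of stack_for_pool's gather, keyed by the flat window slot Y*px+X. A standalone sketch of the round trip; the xup cropping line inside stack_for_pool is assumed here, since the hunk above only shows part of that function:

import numpy as np

def stack_for_pool(x, py, px):
  my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
  stack = []
  xup = x[:, :, :my, :mx]   # assumed: crop to a whole number of windows
  for Y in range(py):
    for X in range(px):
      stack.append(xup[:, :, Y::py, X::px][None])
  return np.concatenate(stack, axis=0)

def unstack_for_pool(fxn, s, py, px):
  my, mx = (s[2]//py)*py, (s[3]//px)*px
  for Y in range(py):
    for X in range(px):
      ll = fxn(Y*px+X)
      if X == 0 and Y == 0:
        ret = np.zeros(s, dtype=ll.dtype)
      ret[:, :, Y:my:py, X:mx:px] = ll
  return ret

x = np.random.randn(2, 3, 5, 7).astype(np.float32)    # odd H/W to exercise the crop
stacked = stack_for_pool(x, 2, 2)                      # shape (4, 2, 3, 2, 3): one slice per window slot
out = unstack_for_pool(lambda idx: stacked[idx], x.shape, 2, 2)
assert np.array_equal(out[:, :, :4, :6], x[:, :, :4, :6])               # covered region round-trips
assert (out[:, :, 4:, :] == 0).all() and (out[:, :, :, 6:] == 0).all()  # cropped rows/cols stay zero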
@@ -144,13 +165,9 @@ class MaxPool2D(Function):
   @staticmethod
   def backward(ctx, grad_output):
     idxs,s = ctx.saved_tensors
-    py, px = ctx.kernel_size
-    my, mx = (s[2]//py)*py, (s[3]//px)*px
-    ret = np.zeros(s, dtype=grad_output.dtype)
-    for Y in range(py):
-      for X in range(px):
-        ret[:, :, Y:my:py, X:mx:px] = grad_output * (idxs == (Y*px+X))
-    return ret
+    return unstack_for_pool(
+      lambda idx: grad_output * (idxs == idx),
+      s, *ctx.kernel_size)
 register('max_pool2d', MaxPool2D)
 
 class AvgPool2D(Function):
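The scatter loop is now expressed through unstack_for_pool, but behaviorally nothing should change: the gradient still lands only on each window's argmax, tracked by the saved idxs. A quick torch illustration of that invariant, which is what helper_test_op compares against:

import torch
x = torch.randn(1, 1, 4, 4, requires_grad=True)
torch.nn.functional.max_pool2d(x, (2, 2)).sum().backward()
print(x.grad)  # exactly one 1 in each 2x2 window (at its max), zeros elsewhere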
@@ -164,10 +181,8 @@ class AvgPool2D(Function):
   def backward(ctx, grad_output):
     s, = ctx.saved_tensors
     py, px = ctx.kernel_size
-    my, mx = (s[2]//py)*py, (s[3]//px)*px
-    ret = np.zeros(s, dtype=grad_output.dtype)
-    for Y in range(py):
-      for X in range(px):
-        ret[:, :, Y:my:py, X:mx:px] = grad_output/py/px
-    return ret
+    return unstack_for_pool(
+      lambda idx: grad_output/py/px,
+      s, py, px)
 register('avg_pool2d', AvgPool2D)
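The same refactor applied to AvgPool2D; py and px stay as locals here because the backward lambda divides by them. The invariant it must preserve, mirrored in torch:

import torch
x = torch.randn(1, 1, 4, 4, requires_grad=True)
torch.nn.functional.avg_pool2d(x, (2, 2)).sum().backward()
print(x.grad)  # every element of every 2x2 window receives 1/4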