mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix kernel_size bug, name like torch, add test
This commit is contained in:
@@ -32,7 +32,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7):
|
||||
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
for bs in [1,128]:
|
||||
for bs in [1,8]:
|
||||
for cin in [1,3]:
|
||||
for H in [2,5]:
|
||||
for W in [2,3,5]:
|
||||
@@ -43,6 +43,12 @@ class TestOps(unittest.TestCase):
|
||||
def test_maxpool2x2(self):
|
||||
helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
|
||||
|
||||
def test_maxpool_sizes(self):
|
||||
for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
helper_test_op([(32,2,110,28)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=sz))
|
||||
|
||||
def test_avgpool2x2(self):
|
||||
helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d)
|
||||
|
||||
|
||||
@@ -130,13 +130,13 @@ def stack_for_pool(x, py, px):
|
||||
xup = x[:, :, :my, :mx]
|
||||
for Y in range(py):
|
||||
for X in range(px):
|
||||
stack.append(xup[:, :, Y::2, X::2][None])
|
||||
stack.append(xup[:, :, Y::py, X::px][None])
|
||||
return np.concatenate(stack, axis=0)
|
||||
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, pool_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *pool_size)
|
||||
def forward(ctx, x, kernel_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *kernel_size)
|
||||
idxs = np.argmax(stack, axis=0)
|
||||
ctx.save_for_backward(idxs, x.shape)
|
||||
return np.max(stack, axis=0)
|
||||
@@ -144,7 +144,7 @@ class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
idxs,s = ctx.saved_tensors
|
||||
py, px = ctx.pool_size
|
||||
py, px = ctx.kernel_size
|
||||
my, mx = (s[2]//py)*py, (s[3]//px)*px
|
||||
ret = np.zeros(s, dtype=grad_output.dtype)
|
||||
for Y in range(py):
|
||||
@@ -155,15 +155,15 @@ register('max_pool2d', MaxPool2D)
|
||||
|
||||
class AvgPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, pool_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *pool_size)
|
||||
def forward(ctx, x, kernel_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *kernel_size)
|
||||
ctx.save_for_backward(x.shape)
|
||||
return np.mean(stack, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
s, = ctx.saved_tensors
|
||||
py, px = ctx.pool_size
|
||||
py, px = ctx.kernel_size
|
||||
my, mx = (s[2]//py)*py, (s[3]//px)*px
|
||||
ret = np.zeros(s, dtype=grad_output.dtype)
|
||||
for Y in range(py):
|
||||
|
||||
Reference in New Issue
Block a user