always keep batch size out front

This commit is contained in:
George Hotz
2020-10-25 08:14:07 -07:00
parent b91fd3afad
commit 935f5ddaaa
4 changed files with 25 additions and 11 deletions

View File

@@ -1,22 +1,35 @@
#!/usr/bin/env python
import builtins
# Optional line-by-line profiling: when line_profiler is installed, a
# LineProfiler instance is published as the builtin `profile`, so functions
# can be probed by adding an @profile decorator.
try:
    import line_profiler
except ImportError:
    # line_profiler is optional. Bind `prof` anyway so that later
    # `prof is not None` checks do not raise NameError.
    prof = None
else:
    prof = line_profiler.LineProfiler()
    # add @profile decorator to probe
    builtins.__dict__['profile'] = prof
import cProfile
import unittest
from tinygrad.tensor import Tensor
def profile_conv(bs, chans, conv, cnt=10):
    """Run repeated conv2d forward and backward passes as a profiling workload.

    Args:
        bs: first dimension of the zero-filled input, which is created with
            shape (bs, 1, 28, 28).
        chans: first dimension of the random filter tensor.
        conv: size used for both trailing dimensions of the filter tensor,
            which is created with shape (chans, 1, conv, conv).
        cnt: number of forward/backward iterations to run.

    Returns:
        None. The function exists only to generate work for a profiler.
    """
    img = Tensor.zeros(bs, 1, 28, 28)
    # Use a separate name so the integer `conv` argument is not rebound.
    filters = Tensor.randn(chans, 1, conv, conv)
    for _ in range(cnt):
        out = img.conv2d(filters)
        out.mean().backward()
class TestConvSpeed(unittest.TestCase):
    """Profiles a conv2d forward/backward workload and prints the statistics."""

    def test_forward_backward_3x3(self):
        """Profile `profile_conv(128, 16, 3)` under cProfile and print the stats."""
        pr = cProfile.Profile()
        pr.enable()
        try:
            profile_conv(128, 16, 3)
        finally:
            # Always stop the profiler, even if the workload raises.
            pr.disable()
        pr.print_stats(sort='time')

        # `prof` is only bound at module level when line_profiler could be
        # imported, so look it up defensively instead of risking a NameError.
        line_prof = globals().get('prof')
        if line_prof is not None:
            line_prof.print_stats()
if __name__ == '__main__':
unittest.main()

View File

@@ -78,8 +78,8 @@ class TestOps(unittest.TestCase):
out.mean().backward()
ret.mean().backward()
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
def test_maxpool2x2(self):
x = torch.randn((5,2,10,8), requires_grad=True)

View File

@@ -140,10 +140,10 @@ class FastConv2D(Function):
ctx.save_for_backward(tx, w)
# now the conv is a GEMM
ret = tx.dot(tw).reshape(oy, ox, bs, cout)
ret = tx.dot(tw).reshape(bs, oy, ox, cout)
# order correctly
return np.moveaxis(ret, [0,1,2,3], [2,3,0,1])
return np.moveaxis(ret, [0,1,2,3], [0,2,3,1])
@staticmethod
def backward(ctx, grad_output):
@@ -153,7 +153,7 @@ class FastConv2D(Function):
tw = w.reshape(w.shape[0], -1)
# order correctly
gg = np.moveaxis(grad_output, [0,1,2,3], [2,3,0,1]).reshape(-1, cout)
gg = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1]).reshape(-1, cout)
# dw is easy
dw = gg.T.dot(tx).reshape(w.shape)

View File

@@ -34,20 +34,21 @@ def fetch_mnist():
def im2col(x, H, W):
    """Unfold every H x W sliding window of `x` into one row of a 2-D matrix.

    Args:
        x: 4-D array. Axes 0 and 1 are kept together as batch and channel;
            the windows slide over axes 2 and 3 with a step of one.
        H: window extent along axis 2.
        W: window extent along axis 3.

    Returns:
        Array of shape (bs * oy * ox, cin * H * W) with the dtype of `x`,
        where oy = x.shape[2] - (H - 1) and ox = x.shape[3] - (W - 1).
        Rows are ordered batch first, then window row, then window column;
        each row is the flattened (cin, H, W) window.
    """
    bs, cin = x.shape[0], x.shape[1]
    oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)
    # Batch size stays on the leading axis so the final reshape yields
    # batch-major rows.
    tx = np.empty((bs, oy, ox, cin * W * H), dtype=x.dtype)
    for Y in range(oy):
        for X in range(ox):
            tx[:, Y, X] = x[:, :, Y:Y + H, X:X + W].reshape(bs, -1)
    return tx.reshape(-1, cin * W * H)
def col2im(tx, H, W, OY, OX):
    """Scatter-add window rows back into a (bs, cin, OY, OX) array.

    Args:
        tx: 2-D array with bs * oy * ox rows, ordered batch first, then
            window row, then window column. Each row holds cin * H * W
            values, read as a (cin, H, W) window.
        H: window extent along the OY axis.
        W: window extent along the OX axis.
        OY: size of axis 2 of the result.
        OX: size of axis 3 of the result.

    Returns:
        Array of shape (bs, cin, OY, OX) with the dtype of `tx`. Overlapping
        windows are summed, so each output element accumulates every window
        that covers it.
    """
    oy, ox = OY - (H - 1), OX - (W - 1)
    bs = tx.shape[0] // (oy * ox)
    cin = tx.shape[1] // (H * W)
    # Batch size is the leading axis of the unfolded layout.
    tx = tx.reshape(bs, oy, ox, cin, H, W)
    x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
    for Y in range(oy):
        for X in range(ox):
            x[:, :, Y:Y + H, X:X + W] += tx[:, Y, X]
    return x