always keep batch size out front

George Hotz
2020-10-25 08:14:07 -07:00
parent b91fd3afad
commit 935f5ddaaa
4 changed files with 25 additions and 11 deletions

View File

@@ -1,22 +1,35 @@
 #!/usr/bin/env python
+import builtins
+try:
+  import line_profiler
+  prof = line_profiler.LineProfiler()
+  builtins.__dict__['profile'] = prof
+  # add @profile decorator to probe
+except ImportError:
+  prof = None
 import cProfile
 import unittest
 from tinygrad.tensor import Tensor
 
-def profile_conv(bs, chans, conv, cnt=100):
+def profile_conv(bs, chans, conv, cnt=10):
   img = Tensor.zeros(bs, 1, 28, 28)
   conv = Tensor.randn(chans, 1, conv, conv)
   for i in range(cnt):
     out = img.conv2d(conv)
+    g = out.mean().backward()
 
 class TestConvSpeed(unittest.TestCase):
-  def test_forward_3x3(self):
+  def test_forward_backward_3x3(self):
     pr = cProfile.Profile()
     pr.enable()
     profile_conv(128, 16, 3)
     pr.disable()
     pr.print_stats(sort='time')
+    if prof is not None:
+      prof.print_stats()
 
 if __name__ == '__main__':
   unittest.main()
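A note on the hook added above: stashing the LineProfiler in builtins makes a bare @profile decorator resolve anywhere in the process without an import, and falling back to prof = None keeps the test runnable when line_profiler isn't installed. A minimal sketch of the same pattern; the no-op fallback decorator and the work() function are illustrative additions, not part of this commit:

import builtins

try:
  import line_profiler
  prof = line_profiler.LineProfiler()
  builtins.__dict__['profile'] = prof          # real line-by-line profiler
except ImportError:
  prof = None
  builtins.__dict__['profile'] = lambda f: f   # no-op keeps @profile valid

@profile        # resolves through builtins, no import at the use site
def work():     # hypothetical stand-in for the code being probed
  return sum(i*i for i in range(10000))

work()
if prof is not None:
  prof.print_stats()

With line_profiler installed, prof.print_stats() shows per-line timings for work(); without it, the script still runs.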

View File

@@ -78,8 +78,8 @@ class TestOps(unittest.TestCase):
     out.mean().backward()
     ret.mean().backward()
-    np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
-    np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
+    np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
+    np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
 
   def test_maxpool2x2(self):
     x = torch.randn((5,2,10,8), requires_grad=True)

View File

@@ -140,10 +140,10 @@ class FastConv2D(Function):
     ctx.save_for_backward(tx, w)
 
     # now the conv is a GEMM
-    ret = tx.dot(tw).reshape(oy, ox, bs, cout)
+    ret = tx.dot(tw).reshape(bs, oy, ox, cout)
 
     # order correctly
-    return np.moveaxis(ret, [0,1,2,3], [2,3,0,1])
+    return np.moveaxis(ret, [0,1,2,3], [0,2,3,1])
 
   @staticmethod
   def backward(ctx, grad_output):
@@ -153,7 +153,7 @@ class FastConv2D(Function):
     tw = w.reshape(w.shape[0], -1)
 
     # order correctly
-    gg = np.moveaxis(grad_output, [0,1,2,3], [2,3,0,1]).reshape(-1, cout)
+    gg = np.moveaxis(grad_output, [0,1,2,3], [0,3,1,2]).reshape(-1, cout)
 
     # dw is easy
     dw = gg.T.dot(tx).reshape(w.shape)
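Worth unpacking the moveaxis change: np.moveaxis(a, src, dst) sends axis src[i] to position dst[i]. With the old (oy, ox, bs, cout) layout, the reorder [2,3,0,1] was its own inverse, so forward and backward could share the same call. With batch out front, the forward reorder of (bs, oy, ox, cout) to NCHW is [0,2,3,1], and its inverse, which backward needs to undo the reorder on grad_output, is [0,3,1,2]. A quick check with illustrative sizes:

import numpy as np

bs, oy, ox, cout = 2, 26, 26, 16  # illustrative sizes only
ret = np.arange(bs*oy*ox*cout).reshape(bs, oy, ox, cout)
out = np.moveaxis(ret, [0,1,2,3], [0,2,3,1])   # -> (bs, cout, oy, ox), NCHW
assert out.shape == (bs, cout, oy, ox)
back = np.moveaxis(out, [0,1,2,3], [0,3,1,2])  # inverse reorder
assert (back == ret).all()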

View File

@@ -34,20 +34,21 @@ def fetch_mnist():
 def im2col(x, H, W):
   bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
-  tx = np.empty((oy, ox, bs, cin*W*H), dtype=x.dtype)
+  tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
   for Y in range(oy):
     for X in range(ox):
-      tx[Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
+      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
   return tx.reshape(-1, cin*W*H)
 
 def col2im(tx, H, W, OY, OX):
   oy, ox = OY-(H-1), OX-(W-1)
   bs = tx.shape[0] // (oy * ox)
   cin = tx.shape[1] // (H * W)
-  tx = tx.reshape(oy, ox, bs, cin, H, W)
+  tx = tx.reshape(bs, oy, ox, cin, H, W)
   x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
   for Y in range(oy):
     for X in range(ox):
-      x[:, :, Y:Y+H, X:X+W] += tx[Y, X]
+      x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X]
   return x
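As a sanity check on the batch-first im2col, the GEMM-plus-reorder path should reproduce a naive sliding-window cross-correlation. A self-contained sketch with illustrative sizes, assuming tw in the forward is w.reshape(cout, -1).T, which the shapes imply but the hunk doesn't show:

import numpy as np

def im2col(x, H, W):
  # batch-first patch extraction, mirroring the new layout above
  bs, cin = x.shape[0], x.shape[1]
  oy, ox = x.shape[2]-(H-1), x.shape[3]-(W-1)
  tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
  for Y in range(oy):
    for X in range(ox):
      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
  return tx.reshape(-1, cin*W*H)

bs, cin, cout, H, W = 2, 3, 4, 3, 3
x = np.random.randn(bs, cin, 8, 8)
w = np.random.randn(cout, cin, H, W)
oy, ox = 8-(H-1), 8-(W-1)

# conv as a GEMM, then reorder to NCHW as in FastConv2D.forward
ret = im2col(x, H, W).dot(w.reshape(cout, -1).T).reshape(bs, oy, ox, cout)
out = np.moveaxis(ret, [0,1,2,3], [0,2,3,1])

# reference: naive valid cross-correlation, one output pixel at a time
ref = np.zeros((bs, cout, oy, ox))
for Y in range(oy):
  for X in range(ox):
    ref[:, :, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1).dot(w.reshape(cout, -1).T)
np.testing.assert_allclose(out, ref)

The row order of im2col's output is now (bs, oy, ox), matching the reshape in FastConv2D.forward, so the batch dimension stays out front end to end.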