mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
always keep batch size out front
This commit is contained in:
@@ -1,22 +1,35 @@
|
|||||||
#!/usr/bin/env python
import builtins

# Optional line-by-line profiling: when line_profiler is installed, expose its
# @profile decorator globally so hot functions can opt in. `prof` must be
# initialized to None first — otherwise, when the import fails, the later
# `if prof is not None` check raises NameError instead of skipping cleanly.
prof = None
try:
  import line_profiler
  prof = line_profiler.LineProfiler()
  builtins.__dict__['profile'] = prof
  # add @profile decorator to probe
except ImportError:
  pass
|
||||||
import cProfile
|
import cProfile
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
def profile_conv(bs, chans, conv, cnt=10):
  """Run `cnt` forward+backward passes of a conv2d for profiling.

  bs: batch size, chans: output channels, conv: square kernel size.
  Input is a zero MNIST-shaped batch (bs, 1, 28, 28); the kernel is random.
  """
  img = Tensor.zeros(bs, 1, 28, 28)
  kernel = Tensor.randn(chans, 1, conv, conv)
  for _ in range(cnt):
    # forward, then backprop through the mean to exercise the full pass
    img.conv2d(kernel).mean().backward()
||||||
class TestConvSpeed(unittest.TestCase):
  """Profile conv2d forward+backward speed (prints stats, asserts nothing)."""

  def test_forward_backward_3x3(self):
    # cProfile around a 3x3 conv: batch 128, 16 output channels.
    pr = cProfile.Profile()
    pr.enable()
    profile_conv(128, 16, 3)
    pr.disable()
    pr.print_stats(sort='time')

    # `prof` only exists when line_profiler imported successfully at module
    # top; look it up defensively so the test doesn't NameError without it.
    if globals().get('prof') is not None:
      prof.print_stats()
||||||
# Run the unittest CLI when executed as a script.
if __name__ == '__main__':
  unittest.main()
|
|
||||||
|
|||||||
@@ -78,8 +78,8 @@ class TestOps(unittest.TestCase):
|
|||||||
out.mean().backward()
|
out.mean().backward()
|
||||||
ret.mean().backward()
|
ret.mean().backward()
|
||||||
|
|
||||||
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
|
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
|
||||||
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
|
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
|
||||||
|
|
||||||
def test_maxpool2x2(self):
|
def test_maxpool2x2(self):
|
||||||
x = torch.randn((5,2,10,8), requires_grad=True)
|
x = torch.randn((5,2,10,8), requires_grad=True)
|
||||||
|
|||||||
@@ -140,10 +140,10 @@ class FastConv2D(Function):
|
|||||||
ctx.save_for_backward(tx, w)
|
ctx.save_for_backward(tx, w)
|
||||||
|
|
||||||
# now the conv is a GEMM
|
# now the conv is a GEMM
|
||||||
ret = tx.dot(tw).reshape(oy, ox, bs, cout)
|
ret = tx.dot(tw).reshape(bs, oy, ox, cout)
|
||||||
|
|
||||||
# order correctly
|
# order correctly
|
||||||
return np.moveaxis(ret, [0,1,2,3], [2,3,0,1])
|
return np.moveaxis(ret, [0,1,2,3], [0,2,3,1])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
@@ -153,7 +153,7 @@ class FastConv2D(Function):
|
|||||||
tw = w.reshape(w.shape[0], -1)
|
tw = w.reshape(w.shape[0], -1)
|
||||||
|
|
||||||
# order correctly
|
# order correctly
|
||||||
gg = np.moveaxis(grad_output, [0,1,2,3], [2,3,0,1]).reshape(-1, cout)
|
gg = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1]).reshape(-1, cout)
|
||||||
|
|
||||||
# dw is easy
|
# dw is easy
|
||||||
dw = gg.T.dot(tx).reshape(w.shape)
|
dw = gg.T.dot(tx).reshape(w.shape)
|
||||||
|
|||||||
@@ -34,20 +34,21 @@ def fetch_mnist():
|
|||||||
def im2col(x, H, W):
  """Unroll every HxW receptive field of a (bs, cin, y, x) batch into a row.

  Returns a (bs*oy*ox, cin*W*H) matrix with rows in batch-major order
  (bs outermost, then Y, then X), so the conv becomes a single GEMM.
  """
  bs, cin = x.shape[0], x.shape[1]
  oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)
  # batch axis stays out front so rows group by sample
  cols = np.empty((bs, oy, ox, cin * W * H), dtype=x.dtype)
  for yy in range(oy):
    for xx in range(ox):
      patch = x[:, :, yy:yy + H, xx:xx + W]
      cols[:, yy, xx] = patch.reshape(bs, -1)
  return cols.reshape(-1, cin * W * H)
|
||||||
def col2im(tx, H, W, OY, OX):
  """Inverse of im2col: scatter-add unrolled patch rows back into images.

  tx: (bs*oy*ox, cin*H*W) rows in batch-major order (as im2col produces).
  Overlapping patches accumulate by summation, which is what the conv
  backward pass needs. Returns a (bs, cin, OY, OX) array.
  """
  oy, ox = OY - (H - 1), OX - (W - 1)
  # recover batch and channel counts from the flattened row/column sizes
  bs = tx.shape[0] // (oy * ox)
  cin = tx.shape[1] // (H * W)
  patches = tx.reshape(bs, oy, ox, cin, H, W)
  out = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
  for yy in range(oy):
    for xx in range(ox):
      out[:, :, yy:yy + H, xx:xx + W] += patches[:, yy, xx]
  return out
|
||||||
|
|||||||
Reference in New Issue
Block a user