mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
always keep batch size out front
@@ -1,22 +1,35 @@
 #!/usr/bin/env python
+import builtins
+try:
+  import line_profiler
+  prof = line_profiler.LineProfiler()
+  builtins.__dict__['profile'] = prof
+  # add @profile decorator to probe
+except ImportError:
+  prof = None
+
 import cProfile
 import unittest
 from tinygrad.tensor import Tensor
 
-def profile_conv(bs, chans, conv, cnt=100):
+def profile_conv(bs, chans, conv, cnt=10):
   img = Tensor.zeros(bs, 1, 28, 28)
   conv = Tensor.randn(chans, 1, conv, conv)
   for i in range(cnt):
     out = img.conv2d(conv)
     g = out.mean().backward()
 
 class TestConvSpeed(unittest.TestCase):
-  def test_forward_3x3(self):
+  def test_forward_backward_3x3(self):
     pr = cProfile.Profile()
     pr.enable()
     profile_conv(128, 16, 3)
     pr.disable()
     pr.print_stats(sort='time')
+
+    if prof is not None:
+      prof.print_stats()
 
 if __name__ == '__main__':
   unittest.main()
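Note: the builtins registration above makes line_profiler's @profile decorator resolve in every module without an explicit import; decorated functions are timed line by line and reported by prof.print_stats(). A minimal standalone sketch of the pattern (inner_loop is a hypothetical hot function, not part of this commit):

    import builtins
    try:
      import line_profiler
      prof = line_profiler.LineProfiler()
      builtins.__dict__['profile'] = prof         # @profile now resolves everywhere
    except ImportError:
      prof = None
      builtins.__dict__['profile'] = lambda f: f  # no-op fallback so @profile still works

    @profile
    def inner_loop(xs):
      # stand-in for a hot function such as conv2d
      return sum(x * x for x in xs)

    inner_loop(range(10000))
    if prof is not None:
      prof.print_stats()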
@@ -78,8 +78,8 @@ class TestOps(unittest.TestCase):
     out.mean().backward()
     ret.mean().backward()
 
-    np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
-    np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
+    np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
+    np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
 
   def test_maxpool2x2(self):
     x = torch.randn((5,2,10,8), requires_grad=True)
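Note: np.testing.assert_allclose checks |actual - desired| <= atol + rtol * |desired| elementwise (rtol defaults to 1e-7), so this change tightens the absolute tolerance on the conv gradients from 1e-5 to 1e-7. A one-line illustration:

    import numpy as np
    # passes: the difference of 5e-8 is within atol=1e-7 (+ default rtol * |desired|)
    np.testing.assert_allclose(np.array([1.0 + 5e-8]), np.array([1.0]), atol=1e-7)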
@@ -140,10 +140,10 @@ class FastConv2D(Function):
     ctx.save_for_backward(tx, w)
 
     # now the conv is a GEMM
-    ret = tx.dot(tw).reshape(oy, ox, bs, cout)
+    ret = tx.dot(tw).reshape(bs, oy, ox, cout)
 
     # order correctly
-    return np.moveaxis(ret, [0,1,2,3], [2,3,0,1])
+    return np.moveaxis(ret, [0,1,2,3], [0,2,3,1])
 
   @staticmethod
   def backward(ctx, grad_output):
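Note: with batch out front, the GEMM result is NHWC-like (bs, oy, ox, cout), and the moveaxis call reorders it to the NCHW layout the rest of the framework expects. A quick standalone shape check with illustrative sizes:

    import numpy as np

    bs, oy, ox, cout = 2, 5, 4, 3
    ret = np.zeros((bs, oy, ox, cout))
    # axis 0 (bs) stays put; oy -> position 2, ox -> position 3, cout -> position 1
    out = np.moveaxis(ret, [0,1,2,3], [0,2,3,1])
    assert out.shape == (bs, cout, oy, ox)   # NCHW, batch size out front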
@@ -153,7 +153,7 @@ class FastConv2D(Function):
     tw = w.reshape(w.shape[0], -1)
 
     # order correctly
-    gg = np.moveaxis(grad_output, [0,1,2,3], [2,3,0,1]).reshape(-1, cout)
+    gg = np.moveaxis(grad_output, [0,1,2,3], [0,3,1,2]).reshape(-1, cout)
 
     # dw is easy
     dw = gg.T.dot(tx).reshape(w.shape)
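Note: the backward pass needs the inverse permutation (NCHW back to NHWC) so that gg's rows line up with tx's (bs, oy, ox) row order before the reshape to (-1, cout); destination [0,3,1,2] is the inverse of the forward [0,2,3,1]. A standalone round-trip check:

    import numpy as np

    bs, cout, oy, ox = 2, 3, 5, 4
    g = np.arange(bs*cout*oy*ox).reshape(bs, cout, oy, ox)
    gg = np.moveaxis(g, [0,1,2,3], [0,3,1,2])       # NCHW -> NHWC
    assert gg.shape == (bs, oy, ox, cout)
    # applying the forward permutation again restores the original layout
    np.testing.assert_array_equal(np.moveaxis(gg, [0,1,2,3], [0,2,3,1]), g)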
@@ -34,20 +34,21 @@ def fetch_mnist():
 def im2col(x, H, W):
   bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
 
-  tx = np.empty((oy, ox, bs, cin*W*H), dtype=x.dtype)
+  tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
   for Y in range(oy):
     for X in range(ox):
-      tx[Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
+      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
 
   return tx.reshape(-1, cin*W*H)
 
 def col2im(tx, H, W, OY, OX):
   oy, ox = OY-(H-1), OX-(W-1)
   bs = tx.shape[0] // (oy * ox)
   cin = tx.shape[1] // (H * W)
-  tx = tx.reshape(oy, ox, bs, cin, H, W)
+  tx = tx.reshape(bs, oy, ox, cin, H, W)
 
   x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
   for Y in range(oy):
     for X in range(ox):
-      x[:, :, Y:Y+H, X:X+W] += tx[Y, X]
+      x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X]
   return x
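Note: with the batch dimension leading, im2col's flattened rows come out in (bs, oy, ox) order, so the whole convolution reduces to one matrix multiply against the reshaped filters. A standalone sanity check against a naive sliding-window convolution (illustrative shapes; assumes the new im2col above is in scope):

    import numpy as np

    def conv2d_naive(x, w):
      bs, cin, iy, ix = x.shape
      cout, _, H, W = w.shape
      oy, ox = iy-(H-1), ix-(W-1)
      out = np.zeros((bs, cout, oy, ox))
      for Y in range(oy):
        for X in range(ox):
          patch = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)        # (bs, cin*H*W)
          out[:, :, Y, X] = patch.dot(w.reshape(cout, -1).T)   # (bs, cout)
      return out

    bs, cin, cout, H, W = 2, 3, 4, 3, 3
    x = np.random.randn(bs, cin, 6, 6)
    w = np.random.randn(cout, cin, H, W)
    oy, ox = 6-(H-1), 6-(W-1)
    tx = im2col(x, H, W)                                   # (bs*oy*ox, cin*H*W)
    ret = tx.dot(w.reshape(cout, -1).T).reshape(bs, oy, ox, cout)
    out = np.moveaxis(ret, [0,1,2,3], [0,2,3,1])           # back to NCHW
    np.testing.assert_allclose(out, conv2d_naive(x, w), atol=1e-10)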