fast im2col

This commit is contained in:
George Hotz
2020-10-25 11:49:35 -07:00
parent c9968756d1
commit 67506eb6ba
2 changed files with 11 additions and 7 deletions

View File

@@ -67,7 +67,7 @@ class TestTinygrad(unittest.TestCase):
class TestOps(unittest.TestCase):
def test_conv2d(self):
x = torch.randn((5,2,10,7), requires_grad=True)
w = torch.randn((4,2,3,3), requires_grad=True)
w = torch.randn((4,2,3,2), requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())

View File

@@ -1,4 +1,5 @@
import numpy as np
from functools import lru_cache
def mask_like(like, mask_inx, mask_value = 1.0):
mask = np.zeros_like(like).reshape(-1)
@@ -31,14 +32,17 @@ def fetch_mnist():
# these are matlab functions used to speed up convs
# write them fast and the convs will be fast?
@lru_cache
def get_im2col_indexes(oy, ox, cin, H, W):
idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox)
idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W)
idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W)
return idxc, idxy, idxx
def im2col(x, H, W):
bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
# TODO: use something like np.take for speed
tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
for Y in range(oy):
for X in range(ox):
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
tx = x[:, ic, iy, ix]
return tx.reshape(-1, cin*W*H)
def col2im(tx, H, W, OY, OX):