mirror of https://github.com/tinygrad/tinygrad.git
fast im2col
@@ -67,7 +67,7 @@ class TestTinygrad(unittest.TestCase):
 class TestOps(unittest.TestCase):
   def test_conv2d(self):
     x = torch.randn((5,2,10,7), requires_grad=True)
-    w = torch.randn((4,2,3,3), requires_grad=True)
+    w = torch.randn((4,2,3,2), requires_grad=True)
     xt = Tensor(x.detach().numpy())
     wt = Tensor(w.detach().numpy())

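One-line test change: the conv weight goes from a square (4,2,3,3) kernel to a non-square (4,2,3,2) one. A plausible reading is that an asymmetric kernel exposes any H/W mix-up in the new index math, which a square kernel would silently mask. A minimal sketch of the effect (the names here are illustrative, not from the commit):

import numpy as np

x = np.random.randn(5, 2, 10, 7).astype(np.float32)
H, W = 3, 2                                   # non-square kernel, as in the new test
oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)
print(oy, ox)                                 # 8 6 -- swapping H and W would yield 9 5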
@@ -1,4 +1,5 @@
 import numpy as np
+from functools import lru_cache

 def mask_like(like, mask_inx, mask_value = 1.0):
   mask = np.zeros_like(like).reshape(-1)
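The new import backs the @lru_cache decorator used below: the im2col index arrays depend only on the integer tuple (oy, ox, cin, H, W), so they are built once per distinct shape and reused on every later call. Note that the bare @lru_cache form (no parentheses) requires Python 3.8+. A toy sketch of the caching behavior, using a hypothetical stand-in function:

from functools import lru_cache

@lru_cache                       # bare decorator form needs Python 3.8+
def get_indexes(oy, ox):         # hypothetical stand-in for get_im2col_indexes
  print("computing")             # runs only on a cache miss
  return list(range(oy * ox))

get_indexes(8, 6)                # prints "computing"
get_indexes(8, 6)                # cache hit: nothing printed
print(get_indexes.cache_info())  # CacheInfo(hits=1, misses=1, ...)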
@@ -31,14 +32,17 @@ def fetch_mnist():
 # these are matlab functions used to speed up convs
 # write them fast and the convs will be fast?

+@lru_cache
+def get_im2col_indexes(oy, ox, cin, H, W):
+  idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox)
+  idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W)
+  idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W)
+  return idxc, idxy, idxx
+
 def im2col(x, H, W):
   bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
-
-  # TODO: use something like np.take for speed
-  tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
-  for Y in range(oy):
-    for X in range(ox):
-      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
+  ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
+  tx = x[:, ic, iy, ix]
   return tx.reshape(-1, cin*W*H)

 def col2im(tx, H, W, OY, OX):
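The rewritten im2col replaces the per-pixel Python loop with a single fancy-indexing gather over precomputed channel/row/column indexes. To convince yourself the two versions produce identical output, here is a small comparison harness (mine, not part of the commit; both versions are copied from the diff above):

import numpy as np
from functools import lru_cache

@lru_cache
def get_im2col_indexes(oy, ox, cin, H, W):   # copied from the hunk above
  idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox)
  idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W)
  idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W)
  return idxc, idxy, idxx

def im2col_loop(x, H, W):
  # removed version: copies one HxW patch per output pixel in Python
  bs, cin, oy, ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
  tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
  for Y in range(oy):
    for X in range(ox):
      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
  return tx.reshape(-1, cin*W*H)

def im2col_fast(x, H, W):
  # new version: one vectorized gather using the cached indexes
  bs, cin, oy, ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
  ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
  return x[:, ic, iy, ix].reshape(-1, cin*W*H)

x = np.random.randn(5, 2, 10, 7).astype(np.float32)
assert np.allclose(im2col_loop(x, 3, 2), im2col_fast(x, 3, 2))
print("loop and fast im2col agree")

The speedup comes from moving patch extraction out of a Python loop and into numpy's advanced indexing; the removed TODO ("use something like np.take for speed") was pointing in the same direction.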