From 0f02084805e1e24df85591be55af2d2108ec9d2f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 25 Oct 2020 12:13:58 -0700 Subject: [PATCH] conv isn't fast yet --- tinygrad/utils.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tinygrad/utils.py b/tinygrad/utils.py index b926f2fcdd..330c64ee43 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -33,37 +33,50 @@ def fetch_mnist(): # write them fast and the convs will be fast? @lru_cache -def get_im2col_indexes(oy, ox, cin, H, W): +def get_im2col_index(oy, ox, cin, H, W): idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox) idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W) idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W) - return idxc, idxy, idxx + + # why return 3 index when we can return 1? + OY, OX = oy+(H-1), ox+(W-1) + idx = idxc * OY * OX + idxy * OX + idxx + return idx def im2col(x, H, W): bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1) - ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W) - tx = x[:, ic, iy, ix] + idx = get_im2col_index(oy, ox, cin, H, W) + tx = x.reshape(bs, -1)[:, idx] + + """ + # this is slower + tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype) + for Y in range(oy): + for X in range(ox): + tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1) + """ + return tx.reshape(-1, cin*W*H) def col2im(tx, H, W, OY, OX): oy, ox = OY-(H-1), OX-(W-1) bs = tx.shape[0] // (oy * ox) cin = tx.shape[1] // (H * W) - x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype) """ # col2im is just im2col in reverse - tx = tx.reshape(bs, -1) - ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W) - np.add.at(x, (slice(None), ic, iy, ix), tx) + x = np.zeros((bs, cin*OY*OX), dtype=tx.dtype) + idx = get_im2col_index(oy, ox, cin, H, W) + np.add.at(x, (slice(None), idx), tx.reshape(bs, -1)) """ # sadly, this is faster + x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype) tx = tx.reshape(bs, oy, ox, cin, H, W) for Y in range(oy): for X in range(ox): x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X] - return x + return x.reshape(bs, cin, OY, OX)