diff --git a/test/test_conv_speed.py b/test/test_conv_speed.py index 7867ce5675..30fb15b73d 100644 --- a/test/test_conv_speed.py +++ b/test/test_conv_speed.py @@ -32,6 +32,10 @@ def profile_conv(bs, chans, conv, cnt=10): class TestConvSpeed(unittest.TestCase): def test_forward_backward_3x3(self): + # warmup + profile_conv(128, 16, 3, cnt=1) + + # profile pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6) pr.enable() fpt, bpt = profile_conv(128, 16, 3) diff --git a/tinygrad/utils.py b/tinygrad/utils.py index 330c64ee43..93e772a84a 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -43,6 +43,17 @@ def get_im2col_index(oy, ox, cin, H, W): idx = idxc * OY * OX + idxy * OX + idxx return idx +@lru_cache +def swizzle_col2im_index(oy, ox, cin, H, W): + idx = get_im2col_index(oy, ox, cin, H, W) + ridx = np.zeros((np.max(idx)+1, H*W), dtype=idx.dtype)-1 + for i,x in enumerate(idx): + for j in range(H*W): + if ridx[x,j] == -1: + ridx[x,j] = i + break + return ridx + def im2col(x, H, W): bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1) @@ -64,19 +75,26 @@ def col2im(tx, H, W, OY, OX): bs = tx.shape[0] // (oy * ox) cin = tx.shape[1] // (H * W) + ridx = swizzle_col2im_index(oy, ox, cin, H, W) + # -1 has to be 0s + x = np.pad(tx.reshape(bs, -1), ((0,0),(0,1)))[:, ridx].sum(axis=2) + """ - # col2im is just im2col in reverse + # col2im is just im2col in reverse, but np.add.at is SLOW + idx = get_im2col_index(oy, ox, cin, H, W) x = np.zeros((bs, cin*OY*OX), dtype=tx.dtype) idx = get_im2col_index(oy, ox, cin, H, W) np.add.at(x, (slice(None), idx), tx.reshape(bs, -1)) """ + """ # sadly, this is faster x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype) tx = tx.reshape(bs, oy, ox, cin, H, W) for Y in range(oy): for X in range(ox): x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X] + """ return x.reshape(bs, cin, OY, OX)