mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
conv isn't fast yet
This commit is contained in:
@@ -33,37 +33,50 @@ def fetch_mnist():
|
||||
# write them fast and the convs will be fast?
|
||||
|
||||
@lru_cache
|
||||
def get_im2col_indexes(oy, ox, cin, H, W):
|
||||
def get_im2col_index(oy, ox, cin, H, W):
|
||||
idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox)
|
||||
idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W)
|
||||
idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W)
|
||||
return idxc, idxy, idxx
|
||||
|
||||
# why return 3 index when we can return 1?
|
||||
OY, OX = oy+(H-1), ox+(W-1)
|
||||
idx = idxc * OY * OX + idxy * OX + idxx
|
||||
return idx
|
||||
|
||||
def im2col(x, H, W):
|
||||
bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
|
||||
|
||||
ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
|
||||
tx = x[:, ic, iy, ix]
|
||||
idx = get_im2col_index(oy, ox, cin, H, W)
|
||||
tx = x.reshape(bs, -1)[:, idx]
|
||||
|
||||
"""
|
||||
# this is slower
|
||||
tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype)
|
||||
for Y in range(oy):
|
||||
for X in range(ox):
|
||||
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
|
||||
"""
|
||||
|
||||
return tx.reshape(-1, cin*W*H)
|
||||
|
||||
def col2im(tx, H, W, OY, OX):
|
||||
oy, ox = OY-(H-1), OX-(W-1)
|
||||
bs = tx.shape[0] // (oy * ox)
|
||||
cin = tx.shape[1] // (H * W)
|
||||
x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
|
||||
|
||||
"""
|
||||
# col2im is just im2col in reverse
|
||||
tx = tx.reshape(bs, -1)
|
||||
ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
|
||||
np.add.at(x, (slice(None), ic, iy, ix), tx)
|
||||
x = np.zeros((bs, cin*OY*OX), dtype=tx.dtype)
|
||||
idx = get_im2col_index(oy, ox, cin, H, W)
|
||||
np.add.at(x, (slice(None), idx), tx.reshape(bs, -1))
|
||||
"""
|
||||
|
||||
# sadly, this is faster
|
||||
x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
|
||||
tx = tx.reshape(bs, oy, ox, cin, H, W)
|
||||
for Y in range(oy):
|
||||
for X in range(ox):
|
||||
x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X]
|
||||
|
||||
return x
|
||||
return x.reshape(bs, cin, OY, OX)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user