mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
if you want fast convs, revert this
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Function, register
|
||||
from tinygrad.utils import im2col, col2im
|
||||
|
||||
class Reshape(Function):
|
||||
@staticmethod
|
||||
@@ -123,53 +122,8 @@ class Conv2D(Function):
|
||||
dw += gg.T.dot(tx).reshape(dw.shape)
|
||||
dx[:, :, Y:Y+H, X:X+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W)
|
||||
return dx, dw
|
||||
#register('conv2d', Conv2D)
|
||||
# Hand the Conv2D Function to register() under the name 'conv2d'.
register('conv2d', Conv2D)
|
||||
|
||||
class FastConv2D(Function):
    """2D convolution computed as an im2col gather followed by a single GEMM.

    The output size is ``input - (kernel - 1)`` along each spatial axis, i.e.
    the kernel is only placed where it fits entirely inside the input and
    moves one element at a time.
    """

    @staticmethod
    def forward(ctx, x, w):
        """Convolve ``x`` with the filter bank ``w``.

        Args:
            ctx: context object; the im2col matrix and ``w`` are stored on it
                through ``save_for_backward`` for use in ``backward``.
            x: input array indexed as (bs, cin, Y, X).
            w: filter array indexed as (cout, cin, H, W).

        Returns:
            Array of shape (bs, cout, Y-(H-1), X-(W-1)).
        """
        cout, _, H, W = w.shape
        bs, oy, ox = x.shape[0], x.shape[2] - (H - 1), x.shape[3] - (W - 1)

        # im2col: one row per output position, one column per (cin, H, W) tap
        tx = im2col(x, H, W)

        # save the im2col output instead of x (note: it is bigger than x)
        ctx.save_for_backward(tx, w)

        # now the conv is a GEMM: (bs*oy*ox, cin*H*W) @ (cin*H*W, cout)
        ret = tx.dot(w.reshape(cout, -1).T).reshape(bs, oy, ox, cout)

        # order correctly: (bs, oy, ox, cout) -> (bs, cout, oy, ox)
        return np.transpose(ret, (0, 3, 1, 2))

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradients with respect to ``x`` and ``w``.

        Args:
            ctx: context object whose ``saved_tensors`` is unpacked as
                ``(tx, w)`` (presumably what ``forward`` saved).
            grad_output: gradient of shape (bs, cout, oy, ox).

        Returns:
            Tuple ``(dx, dw)``; ``dw`` has the shape of ``w`` and ``dx`` is
            whatever ``col2im`` returns for an
            (oy+H-1, ox+W-1) spatial size.
        """
        _, _, oy, ox = grad_output.shape
        tx, w = ctx.saved_tensors
        cout, _, H, W = w.shape
        # grad_output.shape = (bs, cout, oy, ox)
        # tx.shape = (bs*oy*ox, cin*H*W)
        tw = w.reshape(cout, -1)

        # reshape correctly: (bs, cout, oy, ox) -> (cout, bs*oy*ox)
        ggt = np.transpose(grad_output, (1, 0, 2, 3)).reshape(cout, -1)

        # dw is easy: (cout, bs*oy*ox) @ (bs*oy*ox, cin*H*W)
        dw = ggt.dot(tx).reshape(w.shape)

        # dx is harder: first the per-patch gradient, (bs*oy*ox, cin*H*W),
        # which is the (bs, oy, ox, cin, H, W) layout flattened to 2-D
        dxi = ggt.T.dot(tw)

        # if we im2col on the forward, we col2im on the backward
        dx = col2im(dxi, H, W, oy + (H - 1), ox + (W - 1))

        return dx, dw
|
||||
# Hand the FastConv2D Function to register() under the name 'conv2d'.
# NOTE(review): Conv2D is passed to register() under the same 'conv2d' name
# earlier in this file; how register() treats a repeated name is not visible
# here -- confirm which implementation ends up active.
register('conv2d', FastConv2D)
|
||||
|
||||
# TODO: make this parameterizable
|
||||
class MaxPool2x2(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
from functools import lru_cache
|
||||
|
||||
def mask_like(like, mask_inx, mask_value = 1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
@@ -28,76 +27,3 @@ def fetch_mnist():
|
||||
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
||||
# im2col and col2im mirror the MATLAB functions of the same names; they are used to speed up convs
|
||||
# write them fast and the convs will be fast?
|
||||
|
||||
@lru_cache(maxsize=128)
def get_im2col_index(oy, ox, cin, H, W):
    """Return the flat gather index that turns one image into im2col rows.

    The image is assumed to be flattened from (cin, OY, OX) with
    ``OY = oy + H - 1`` and ``OX = ox + W - 1``. Entry ``k`` of the result is
    the flat position to read for element ``k`` of the patch matrix, whose
    elements are ordered as (oy, ox, cin, H, W).

    Args:
        oy, ox: number of output positions along each spatial axis.
        cin: number of input channels.
        H, W: kernel height and width.

    Returns:
        Read-only 1-D integer array of length ``oy*ox*cin*H*W``. The array is
        cached and shared between callers, so it is locked against writes.
    """
    taps = H * W

    # channel, row and column of every gathered element, in (oy, ox, cin, H, W) order
    idxc = np.tile(np.arange(cin).repeat(taps), oy * ox)
    idxy = (np.tile(np.arange(H).repeat(W), oy * ox * cin)
            + np.arange(oy).repeat(ox * cin * taps))
    idxx = (np.tile(np.arange(W), oy * ox * cin * H)
            + np.tile(np.arange(ox), oy).repeat(cin * taps))

    # why return 3 index when we can return 1?
    OY, OX = oy + (H - 1), ox + (W - 1)
    idx = idxc * OY * OX + idxy * OX + idxx

    # the same object is handed to every caller by the cache; forbid mutation
    idx.flags.writeable = False
    return idx
|
||||
|
||||
@lru_cache(maxsize=128)
def swizzle_col2im_index(oy, ox, cin, H, W):
    """Invert the im2col gather index into a per-input-element lookup table.

    Row ``p`` of the result lists, in increasing order, the positions in the
    im2col gather (see ``get_im2col_index``) that read flat input element
    ``p``. Unused trailing slots hold ``-1``.

    Args:
        oy, ox: number of output positions along each spatial axis.
        cin: number of input channels.
        H, W: kernel height and width.

    Returns:
        Read-only integer array of shape ``(max(idx) + 1, H*W)``. The array
        is cached and shared between callers, so it is locked against writes.
    """
    idx = get_im2col_index(oy, ox, cin, H, W)
    ridx = np.full((np.max(idx) + 1, H * W), -1, dtype=idx.dtype)

    # A stable sort groups the gather positions by the input element they
    # read while keeping them in increasing order inside each group.
    order = np.argsort(idx, kind="stable")
    grouped = idx[order]

    # slot = rank of each gather position within its group
    starts_group = np.ones(len(grouped), dtype=bool)
    starts_group[1:] = grouped[1:] != grouped[:-1]
    group_start = np.flatnonzero(starts_group)
    slot = np.arange(len(grouped)) - group_start[np.cumsum(starts_group) - 1]

    ridx[grouped, slot] = order

    # the same object is handed to every caller by the cache; forbid mutation
    ridx.flags.writeable = False
    return ridx
|
||||
|
||||
def im2col(x, H, W):
    """Unroll every HxW patch of ``x`` into a row of a 2-D matrix.

    Args:
        x: array indexed as (bs, cin, Y, X).
        H, W: kernel height and width.

    Returns:
        Array of shape ``(bs*oy*ox, cin*H*W)`` with ``oy = Y-(H-1)`` and
        ``ox = X-(W-1)``.
    """
    bs, cin = x.shape[0], x.shape[1]
    oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)

    idx = get_im2col_index(oy, ox, cin, H, W)

    # Gather all patch elements of every image with one fancy index.
    # The original author noted that filling the result with a Python loop
    # over output positions, x[:, :, Y:Y+H, X:X+W].reshape(bs, -1), is slower.
    tx = x.reshape(bs, -1)[:, idx]

    # Original author's note on the flattening step: all the time is spent here.
    return tx.reshape(-1, cin * W * H)
|
||||
|
||||
def col2im(tx, H, W, OY, OX):
    """Scatter-add an im2col-style patch matrix back into image layout.

    Args:
        tx: 2-D array of shape ``(bs*oy*ox, cin*H*W)`` with
            ``oy = OY-(H-1)`` and ``ox = OX-(W-1)``.
        H, W: kernel height and width.
        OY, OX: spatial size of the image to rebuild.

    Returns:
        Array of shape ``(bs, cin, OY, OX)`` in which every element is the
        sum of all patch entries that map onto it.

    Raises:
        ValueError: if the shape of ``tx`` is not a whole multiple of the
            geometry implied by ``H``, ``W``, ``OY`` and ``OX``.
    """
    oy, ox = OY - (H - 1), OX - (W - 1)
    positions, taps = oy * ox, H * W
    if tx.shape[0] % positions or tx.shape[1] % taps:
        raise ValueError(
            "tx shape %r does not match %d output positions and %d kernel taps"
            % (tx.shape, positions, taps))
    bs = tx.shape[0] // positions
    cin = tx.shape[1] // taps

    ridx = swizzle_col2im_index(oy, ox, cin, H, W)

    # -1 has to be 0s: pad one zero column at the end so that the -1 filler
    # entries of ridx pick up a zero and do not contribute to the sum
    padded = np.pad(tx.reshape(bs, -1), ((0, 0), (0, 1)))
    x = padded[:, ridx].sum(axis=2)

    # Alternatives recorded by the original author:
    #  * np.add.at(x, (slice(None), idx), ...) with the im2col index -- SLOW
    #  * a Python loop over output positions doing
    #    x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X] -- "sadly, this is faster"
    return x.reshape(bs, cin, OY, OX)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user