From 5d1373c71b240e2d4a9e0d65d6de506c394cc92e Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 25 Oct 2020 17:10:05 -0700 Subject: [PATCH] if you want fast convs, revert this --- tinygrad/ops.py | 48 +----------------------------- tinygrad/utils.py | 74 ----------------------------------------------- 2 files changed, 1 insertion(+), 121 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 25e219d623..b4cb8b1a06 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,6 +1,5 @@ import numpy as np from tinygrad.tensor import Function, register -from tinygrad.utils import im2col, col2im class Reshape(Function): @staticmethod @@ -123,53 +122,8 @@ class Conv2D(Function): dw += gg.T.dot(tx).reshape(dw.shape) dx[:, :, Y:Y+H, X:X+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W) return dx, dw -#register('conv2d', Conv2D) +register('conv2d', Conv2D) -class FastConv2D(Function): - @staticmethod - def forward(ctx, x, w): - cout,cin,H,W = w.shape - tw = w.reshape(cout, -1).T - bs,oy,ox = x.shape[0], x.shape[2]-(H-1), x.shape[3]-(W-1) - - # im2col - tx = im2col(x, H, W) - - # save the im2col output (OMG it's bigger!) - ctx.save_for_backward(tx, w) - - # now the conv is a GEMM - ret = tx.dot(tw).reshape(bs, oy, ox, cout) - - # order correctly - return np.moveaxis(ret, [0,1,2,3], [0,2,3,1]) - - @staticmethod - def backward(ctx, grad_output): - bs,_,oy,ox = grad_output.shape - tx, w = ctx.saved_tensors - cout,cin,H,W = w.shape - # grad_output.shape = (bs, cout, oy, ox) - # tx.shape = (bs*oy*ox*cin, H*W) - tw = w.reshape(w.shape[0], -1) - - # reshape correctly - ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1) - - # dw is easy - dw = ggt.dot(tx).reshape(w.shape) - - # dx is harder - dxi = ggt.T.dot(tw) - - # if we im2col on the forward, we col2im on the backward - # dxi should be (bs, oy, ox, cin, H, W) - dx = col2im(dxi, H, W, oy+(H-1), ox+(W-1)) - - return dx, dw -register('conv2d', FastConv2D) - -# TODO: make this parameterizable class MaxPool2x2(Function): @staticmethod def forward(ctx, x): diff --git a/tinygrad/utils.py b/tinygrad/utils.py index 872653c9eb..b90bfcde31 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -1,5 +1,4 @@ import numpy as np -from functools import lru_cache def mask_like(like, mask_inx, mask_value = 1.0): mask = np.zeros_like(like).reshape(-1) @@ -28,76 +27,3 @@ def fetch_mnist(): Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:] return X_train, Y_train, X_test, Y_test - -# these are matlab functions used to speed up convs -# write them fast and the convs will be fast? - -@lru_cache -def get_im2col_index(oy, ox, cin, H, W): - idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox) - idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*cin*H*W) - idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(cin*H*W) - - # why return 3 index when we can return 1? - OY, OX = oy+(H-1), ox+(W-1) - idx = idxc * OY * OX + idxy * OX + idxx - return idx - -@lru_cache -def swizzle_col2im_index(oy, ox, cin, H, W): - idx = get_im2col_index(oy, ox, cin, H, W) - ridx = np.zeros((np.max(idx)+1, H*W), dtype=idx.dtype)-1 - for i,x in enumerate(idx): - for j in range(H*W): - if ridx[x,j] == -1: - ridx[x,j] = i - break - return ridx - -def im2col(x, H, W): - bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1) - - idx = get_im2col_index(oy, ox, cin, H, W) - tx = x.reshape(bs, -1)[:, idx] - - """ - # this is slower - tx = np.empty((bs, oy, ox, cin*W*H), dtype=x.dtype) - for Y in range(oy): - for X in range(ox): - tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1) - """ - - # all the time is spent here - tx = tx.ravel() - - return tx.reshape(-1, cin*W*H) - -def col2im(tx, H, W, OY, OX): - oy, ox = OY-(H-1), OX-(W-1) - bs = tx.shape[0] // (oy * ox) - cin = tx.shape[1] // (H * W) - - ridx = swizzle_col2im_index(oy, ox, cin, H, W) - # -1 has to be 0s - x = np.pad(tx.reshape(bs, -1), ((0,0),(0,1)))[:, ridx].sum(axis=2) - - """ - # col2im is just im2col in reverse, but np.add.at is SLOW - idx = get_im2col_index(oy, ox, cin, H, W) - x = np.zeros((bs, cin*OY*OX), dtype=tx.dtype) - idx = get_im2col_index(oy, ox, cin, H, W) - np.add.at(x, (slice(None), idx), tx.reshape(bs, -1)) - """ - - """ - # sadly, this is faster - x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype) - tx = tx.reshape(bs, oy, ox, cin, H, W) - for Y in range(oy): - for X in range(ox): - x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X] - """ - - return x.reshape(bs, cin, OY, OX) -