mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
group conv: forward pass works (#34)
* forward pass works * got the backward pass * okay, it's now a coho
This commit is contained in:
@@ -119,8 +119,10 @@ if __name__ == "__main__":
|
||||
# load cat image and preprocess
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
|
||||
img = img.resize((224, 224))
|
||||
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
|
||||
img = img.resize((398, 224))
|
||||
img = np.array(img)
|
||||
img = img[:, 87:-87]
|
||||
img = np.moveaxis(img, [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
img /= 256
|
||||
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
|
||||
|
||||
@@ -48,11 +48,12 @@ class TestOps(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
for bs in [1,8]:
|
||||
for cin in [1,3]:
|
||||
for H in [2,5]:
|
||||
for W in [2,3,5]:
|
||||
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
|
||||
for groups in [1,3] if cin == 3 else [1]:
|
||||
for H in [2,5]:
|
||||
for W in [2,3,5]:
|
||||
helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6)
|
||||
|
||||
def test_strided_conv2d(self):
|
||||
bs = 4
|
||||
|
||||
@@ -155,21 +155,23 @@ class Conv2D(Function):
|
||||
def forward(ctx, x, w, stride=1, groups=1):
|
||||
if type(ctx.stride) == int:
|
||||
ctx.stride = (ctx.stride, ctx.stride)
|
||||
|
||||
cout,cin,H,W = w.shape
|
||||
if groups > 1:
|
||||
w = np.repeat(w, groups, axis=1) / groups
|
||||
tw = w.reshape(cout, -1).T
|
||||
ys,xs = ctx.stride
|
||||
bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||
bs,cin_,oy,ox = x.shape[0], x.shape[1], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||
assert cin*ctx.groups == cin_
|
||||
assert cout % ctx.groups == 0
|
||||
rcout = cout//ctx.groups
|
||||
|
||||
ctx.save_for_backward(x, w)
|
||||
ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
|
||||
for Y in range(oy):
|
||||
for X in range(ox):
|
||||
iY,iX = Y*ys, X*xs
|
||||
tx = x[:, :, iY:iY+H, iX:iX+W].reshape(bs, -1)
|
||||
ret[:, :, Y, X] = tx.dot(tw)
|
||||
|
||||
for g in range(ctx.groups):
|
||||
tw = w[g*rcout:(g*rcout+rcout)].reshape(rcout, -1).T
|
||||
for Y in range(oy):
|
||||
for X in range(ox):
|
||||
iY,iX = Y*ys, X*xs
|
||||
tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(bs, -1)
|
||||
ret[:, g*rcout:(g*rcout+rcout), Y, X] += tx.dot(tw)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
@@ -177,17 +179,19 @@ class Conv2D(Function):
|
||||
bs,_,oy,ox = grad_output.shape
|
||||
x, w = ctx.saved_tensors
|
||||
cout,cin,H,W = w.shape
|
||||
tw = w.reshape(cout, -1)
|
||||
ys,xs = ctx.stride
|
||||
rcout = cout//ctx.groups
|
||||
|
||||
dx, dw = np.zeros_like(x), np.zeros_like(w)
|
||||
for Y in range(grad_output.shape[2]):
|
||||
for X in range(grad_output.shape[3]):
|
||||
iY,iX = Y*ys, X*xs
|
||||
gg = grad_output[:, :, Y, X]
|
||||
tx = x[:, :, iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
|
||||
dw += gg.T.dot(tx).reshape(dw.shape)
|
||||
dx[:, :, iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W)
|
||||
for g in range(ctx.groups):
|
||||
tw = w[g*rcout:(g*rcout+rcout)].reshape(rcout, -1)
|
||||
for Y in range(grad_output.shape[2]):
|
||||
for X in range(grad_output.shape[3]):
|
||||
iY,iX = Y*ys, X*xs
|
||||
gg = grad_output[:, g*rcout:(g*rcout+rcout), Y, X]
|
||||
tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
|
||||
dw[g*rcout:(g*rcout+rcout)] += gg.T.dot(tx).reshape((rcout,cin,H,W))
|
||||
dx[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], cin, H, W)
|
||||
return dx, dw
|
||||
register('conv2d', Conv2D)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user