Conv2D backward on GPU (#93)

* to make it work locally

* definitely not working

* Conv2D GPU passes some of the tests

* Conv2D GPU passes more of the tests

* passes some tests and mnist

* removed unnecessary code

* Conv2D Backpass works

* wrong test_ops.py

* white space + test backward

* erased useless code

* removed default argument

* long lines
Marcel Bischoff, 2020-11-10 19:07:33 -05:00, committed by GitHub
parent 5577b9d3a0, commit 7bb803c5e0
2 changed files with 65 additions and 4 deletions


@@ -74,7 +74,7 @@ class TestOps(unittest.TestCase):
       with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
         helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
           lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5, forward_only=self.gpu)
+          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5)

   def test_strided_conv2d(self):
     bs = 4

@@ -18,7 +18,7 @@ def buffer_zeros(ctx, shape):
 def buffer_like(ctx, x):
   return buffer_new(ctx, x.shape)

-@functools.lru_cache
+@functools.lru_cache()
 def clbuild(cl_ctx, prg):
   return cl.Program(cl_ctx, prg).build()
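The added parentheses are presumably a compatibility fix: using functools.lru_cache bare as a decorator is only accepted on Python 3.8+, while the factory-call form lru_cache() works back to Python 3.2. Either way, clbuild memoizes built programs so each kernel source is compiled once per context. A standalone illustration of the parenthesized form:

import functools

@functools.lru_cache()  # calling the factory: valid before Python 3.8, unlike the bare decorator
def fib(n):
  return n if n < 2 else fib(n-1) + fib(n-2)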
@@ -434,6 +434,8 @@ class Conv2D(Function):
     assert cin*ctx.groups == cin_
     assert cout % ctx.groups == 0
     rcout = cout//ctx.groups

+    ctx.save_for_backward(x,w)
+
     # output buffer
     ret = buffer_new(ctx, (bs, cout, oy, ox))
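forward now stashes its inputs so that backward can retrieve them as ctx.saved_tensors. A minimal sketch of that save/restore pattern, assuming a tinygrad-style Function base class (the base class body here is illustrative, not the repo's actual code):

class Function:
  def __init__(self):
    self.saved_tensors = []
  def save_for_backward(self, *tensors):
    # remember forward's inputs for use in backward
    self.saved_tensors.extend(tensors)

class Mul(Function):
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x*y
  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    return y*grad_output, x*grad_output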
@@ -478,7 +480,66 @@ class Conv2D(Function):
   @staticmethod
   def backward(ctx, grad_output):
-    raise Exception("not implemented")
+    bs,_,oy,ox = grad_output.shape
+    x, w = ctx.saved_tensors
+    cout,cin,H,W = w.shape
+    ys,xs = ctx.stride
+    bs,cin_,iy,ix = x.shape
+    oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
+    assert cin*ctx.groups == cin_
+    assert cout % ctx.groups == 0
+    rcout = cout//ctx.groups
+
+    dx = buffer_zeros(ctx, (bs, cin_, iy, ix))
+    dw = buffer_new(ctx, (cout, cin, H, W))
+
+    prg = clbuild(ctx.cl_ctx, """
+    __kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
+                        int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+      int g = get_global_id(0)/(rcout*cin); // range 0-groups
+      int c = (get_global_id(0)/(cin)) % rcout; // range 0-rcout
+      int ci = get_global_id(0) % cin; // range 0-cin
+      int y = get_global_id(1); // range 0-H
+      int x = get_global_id(2); // range 0-W
+
+      // tensx = (bs, groups*cin, iy, ix)
+      // tensw = (groups*rcout, cin, H, W)
+      // ggg = (bs, groups*rcout, oy, ox)
+      float acc = 0.0;
+      for (int Y = 0; Y < oy; Y++) {
+        for (int X = 0; X < ox; X++) {
+          for (int B = 0; B < bs; B++) {
+            acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] *
+                   tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
+          }
+        }
+      }
+      dw[get_global_id(0)*H*W + y*W + x] = acc;
+    }
+
+    __kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
+                        int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+      int B = get_global_id(0);
+      int g = get_global_id(1);
+      int ci = get_global_id(2);
+
+      for (int c = 0; c < rcout; c++) {
+        for (int Y = 0; Y < oy; Y++) {
+          for (int X = 0; X < ox; X++) {
+            for (int y = 0; y < H; y++) {
+              for (int x = 0; x < W; x++) {
+                dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] +=
+                  ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] *
+                  tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
+              }
+            }
+          }
+        }
+      }
+    }
+    """)
+
+    prg.convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None,
+              x, grad_output, dw, np.int32(H), np.int32(W), np.int32(ctx.groups),
+              np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox),
+              np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+    prg.convx(ctx.cl_queue, [bs, ctx.groups, cin], None,
+              w, grad_output, dx, np.int32(H), np.int32(W), np.int32(ctx.groups),
+              np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox),
+              np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+    return dx, dw

 register('conv2d', Conv2D, gpu=True)
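For intuition: convw computes dL/dw by correlating the input with grad_output (one work-item per weight element, summing over batch and output positions), while convx scatters grad_output back through the weights to accumulate dL/dx. A NumPy cross-check with the same layouts as the kernel comments; conv2d_backward_ref is an illustrative name, not part of the repo:

import numpy as np

def conv2d_backward_ref(x, w, gg, groups=1, ys=1, xs=1):
  # x = (bs, groups*cin, iy, ix), w = (groups*rcout, cin, H, W), gg = (bs, groups*rcout, oy, ox)
  bs, cin_, iy, ix = x.shape
  cout, cin, H, W = w.shape
  rcout = cout // groups
  _, _, oy, ox = gg.shape
  dx, dw = np.zeros_like(x), np.zeros_like(w)
  for B in range(bs):
    for g in range(groups):
      for c in range(rcout):
        for Y in range(oy):
          for X in range(ox):
            gval = gg[B, g*rcout+c, Y, X]
            xpatch = x[B, g*cin:(g+1)*cin, Y*ys:Y*ys+H, X*xs:X*xs+W]
            dw[g*rcout+c] += gval * xpatch  # same sum as convw's acc
            dx[B, g*cin:(g+1)*cin, Y*ys:Y*ys+H, X*xs:X*xs+W] += gval * w[g*rcout+c]  # convx's scatter
  return dx, dw

Reading dx and dw back from the OpenCL buffers and comparing them against this reference with np.testing.assert_allclose is one way to validate the kernels on small shapes.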