Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
Conv2D backward on GPU (#93)
* to make it work locally
* definitely not working
* Conv2D GPU passes some of the tests
* Conv2D GPU passes more of the tests
* passes some tests and mnist
* removed unnecessary code
* Conv2D backward pass works
* wrong test_ops.py
* whitespace + test backward
* erased useless code
* removed default argument
* long lines
@@ -74,7 +74,7 @@ class TestOps(unittest.TestCase):
         with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
           helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
             lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-            lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5, forward_only=self.gpu)
+            lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5)

   def test_strided_conv2d(self):
     bs = 4
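Context for this change: dropping `forward_only=self.gpu` means the test harness now also compares gradients against PyTorch when running on GPU, which is exactly what this commit enables. Outside the harness, the same property can be sanity-checked numerically with `torch.autograd.gradcheck`; a minimal illustration (not part of the commit, double precision is required by gradcheck):

```python
import torch

# grouped conv2d: in_channels=4, groups=2, so weight is (cout=4, cin//groups=2, 3, 3)
x = torch.randn(2, 4, 9, 9, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 2, 3, 3, dtype=torch.double, requires_grad=True)

# compares analytic gradients against finite differences; raises on mismatch
assert torch.autograd.gradcheck(
    lambda x, w: torch.nn.functional.conv2d(x, w, groups=2), (x, w))
```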
@@ -18,7 +18,7 @@ def buffer_zeros(ctx, shape):
 def buffer_like(ctx, x):
   return buffer_new(ctx, x.shape)

-@functools.lru_cache
+@functools.lru_cache()
 def clbuild(cl_ctx, prg):
   return cl.Program(cl_ctx, prg).build()
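The parentheses matter for compatibility: the bare `@functools.lru_cache` decorator form is only supported on Python 3.8+, while on earlier versions it passes the function as the `maxsize` argument and raises a TypeError at import time. `@functools.lru_cache()` works everywhere, and the cache is what keeps `backward` from recompiling the same OpenCL program on every call. A small standalone sketch of the pattern (`build` is a stand-in, not the repo's function):

```python
import functools

@functools.lru_cache()            # parenthesized form works on Python < 3.8 too
def build(src):
    print("compiling")            # executes once per distinct source string
    return hash(src)              # stand-in for cl.Program(cl_ctx, src).build()

build("__kernel void k() {}")     # prints "compiling"
build("__kernel void k() {}")     # cache hit: no recompile
```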
@@ -434,6 +434,8 @@ class Conv2D(Function):
     assert cin*ctx.groups == cin_
     assert cout % ctx.groups == 0
     rcout = cout//ctx.groups

+    ctx.save_for_backward(x,w)
+
     # output buffer
     ret = buffer_new(ctx, (bs, cout, oy, ox))
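The new `ctx.save_for_backward(x,w)` is what makes the backward pass below possible: `backward` needs the forward input `x` to compute the weight gradient and the weights `w` to compute the input gradient, and it reads both back via `ctx.saved_tensors`. A minimal sketch of the save/restore pattern, assuming a `Function` base class along these lines (tinygrad's actual class differs in detail):

```python
class Function:
    def __init__(self):
        self.saved_tensors = []

    def save_for_backward(self, *tensors):
        # forward stashes whatever backward will need
        self.saved_tensors.extend(tensors)

# in forward:  ctx.save_for_backward(x, w)
# in backward: x, w = ctx.saved_tensors
```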
@@ -478,7 +480,66 @@ class Conv2D(Function):

   @staticmethod
   def backward(ctx, grad_output):
-    raise Exception("not implemented")
+    bs,_,oy,ox = grad_output.shape
+    x, w = ctx.saved_tensors
+    cout,cin,H,W = w.shape
+    ys,xs = ctx.stride
+    bs,cin_,iy,ix = x.shape
+    oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
+    assert cin*ctx.groups == cin_
+    assert cout % ctx.groups == 0
+    rcout = cout//ctx.groups
+
+    dx = buffer_zeros(ctx, (bs, cin_, iy, ix))
+    dw = buffer_new(ctx, (cout, cin, H, W))
+
+    prg = clbuild(ctx.cl_ctx, """
+    __kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
+      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+
+      int g = get_global_id(0)/(rcout*cin);   // range 0-groups
+      int c = (get_global_id(0)/(cin)) % rcout; // range 0-rcout
+      int ci = get_global_id(0) % cin;        // range 0-cin
+      int y = get_global_id(1);               // range 0-H
+      int x = get_global_id(2);               // range 0-W
+
+      // tensx = (bs, groups*cin, iy, ix)
+      // tensw = (groups*rcout, cin, H, W)
+      // ggg = (bs, groups*rcout, oy, ox)
+      float acc = 0.0;
+      for (int Y = 0; Y < oy; Y++) {
+        for (int X = 0; X < ox; X++) {
+          for (int B = 0; B < bs; B++) {
+            acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
+          }
+        }
+      }
+      dw[get_global_id(0)*H*W + y*W + x] = acc;
+    }
+    __kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
+      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+
+      int B = get_global_id(0);
+      int g = get_global_id(1);
+      int ci = get_global_id(2);
+
+      for (int c = 0; c < rcout; c++) {
+        for (int Y = 0; Y < oy; Y++) {
+          for (int X = 0; X < ox; X++) {
+            for (int y = 0; y < H; y++) {
+              for (int x = 0; x < W; x++) {
+                dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
+              }
+            }
+          }
+        }
+      }
+    }
+    """)
+
+    prg.convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x, grad_output, dw, np.int32(H), np.int32(W), np.int32(ctx.groups),
+      np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+    prg.convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w, grad_output, dx, np.int32(H), np.int32(W), np.int32(ctx.groups),
+      np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+    return dx, dw
 register('conv2d', Conv2D, gpu=True)
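A note on the index arithmetic: `(iy-(H-ys))//ys` is the usual unpadded output size, since `(iy-H+ys)//ys == (iy-H)//ys + 1`. The `convw` kernel computes the weight gradient as a correlation of the input with `grad_output`, reducing over batch and output positions inside each work-item; `convx` scatters `grad_output * w` back onto the input, parallelized over `(bs, groups, cin)` so each work-item owns a disjoint slice of `dx` and the `+=` writes are race-free. For readers checking the kernels, here is a NumPy sketch of the same computation (a hypothetical reference helper, not part of the commit), mirroring the kernel's shape comments:

```python
import numpy as np

def conv2d_backward_ref(x, w, ggg, groups=1, ys=1, xs=1):
  # shapes follow the kernel comments:
  #   x   = (bs, groups*cin, iy, ix)
  #   w   = (groups*rcout, cin, H, W)
  #   ggg = (bs, groups*rcout, oy, ox)   # grad_output
  bs, cin_, iy, ix = x.shape
  cout, cin, H, W = w.shape
  _, _, oy, ox = ggg.shape
  rcout = cout // groups
  xg = x.reshape(bs, groups, cin, iy, ix)
  wg = w.reshape(groups, rcout, cin, H, W)
  gg = ggg.reshape(bs, groups, rcout, oy, ox)
  dx = np.zeros_like(xg)
  dw = np.zeros_like(wg)
  for Y in range(oy):
    for X in range(ox):
      ysl, xsl = slice(Y*ys, Y*ys+H), slice(X*xs, X*xs+W)
      # convw: accumulate grad * input patch over batch and output positions
      dw += np.einsum('bgr,bgcyx->grcyx', gg[:, :, :, Y, X], xg[:, :, :, ysl, xsl])
      # convx: scatter grad * weight back onto the input patch
      dx[:, :, :, ysl, xsl] += np.einsum('bgr,grcyx->bgcyx', gg[:, :, :, Y, X], wg)
  return dx.reshape(x.shape), dw.reshape(w.shape)
```

The loop bodies correspond line-for-line to the flat-index expressions in `convw` and `convx`, so the reference can be diffed against the GPU buffers to localize an indexing bug to one kernel.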