who loves speeeeed

This commit is contained in:
George Hotz
2020-11-07 11:36:52 -08:00
parent e6c8321e5b
commit b1ca4dd327

View File

@@ -461,7 +461,10 @@ class Conv2D(Function):
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
int B = get_global_id(0); // range 0-bs
int B = get_global_id(0)/(groups*rcout); // range 0-bs
int g = (get_global_id(0)/rcout)%groups;
int c = get_global_id(0) % rcout;
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox
int IY = Y*ys;
@@ -470,24 +473,20 @@ class Conv2D(Function):
// input = (bs, groups, cin, iy, ix)
// weight = (groups, rcout, cin, H, W)
// output = (bs, groups, rcout, oy, ox)
for (int g = 0; g < groups; g++) {
for (int c = 0; c < rcout; c++) {
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = IY; y < IY+H; y++) {
for (int x = IX; x < IX+W; x++) {
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
}
}
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = IY; y < IY+H; y++) {
for (int x = IX; x < IX+W; x++) {
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}
""")
prg.conv(ctx.cl_queue, [bs, oy, ox], None,
prg.conv(ctx.cl_queue, [bs*groups*rcout, oy, ox], None,
x, w, ret,
np.int32(H), np.int32(W),
np.int32(groups), np.int32(rcout), np.int32(cin),