From b1ca4dd327919caf815441737e9647cd343067c5 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 7 Nov 2020 11:36:52 -0800 Subject: [PATCH] who loves speeeeed --- tinygrad/opsgpu.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index 3fdff6de24..aaf9467ef9 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -461,7 +461,10 @@ class Conv2D(Function): __kernel void conv(__global const float *input, __global const float *weight, __global float *output, int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) { - int B = get_global_id(0); // range 0-bs + int B = get_global_id(0)/(groups*rcout); // range 0-bs + int g = (get_global_id(0)/rcout)%groups; + int c = get_global_id(0) % rcout; + int Y = get_global_id(1); // range 0-oy int X = get_global_id(2); // range 0-ox int IY = Y*ys; @@ -470,24 +473,20 @@ class Conv2D(Function): // input = (bs, groups, cin, iy, ix) // weight = (groups, rcout, cin, H, W) // output = (bs, groups, rcout, oy, ox) - for (int g = 0; g < groups; g++) { - for (int c = 0; c < rcout; c++) { - float acc = 0.0; - for (int ci = 0; ci < cin; ci++) { - for (int y = IY; y < IY+H; y++) { - for (int x = IX; x < IX+W; x++) { - acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \ - weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)]; - } - } + float acc = 0.0; + for (int ci = 0; ci < cin; ci++) { + for (int y = IY; y < IY+H; y++) { + for (int x = IX; x < IX+W; x++) { + acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \ + weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)]; } - output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc; } } + output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc; } """) - prg.conv(ctx.cl_queue, [bs, oy, ox], None, + prg.conv(ctx.cl_queue, [bs*groups*rcout, oy, ox], None, x, w, ret, np.int32(H), np.int32(W), np.int32(groups), np.int32(rcout), np.int32(cin),