mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
who loves speeeeed
This commit is contained in:
@@ -461,7 +461,10 @@ class Conv2D(Function):
|
||||
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
|
||||
|
||||
int B = get_global_id(0); // range 0-bs
|
||||
int B = get_global_id(0)/(groups*rcout); // range 0-bs
|
||||
int g = (get_global_id(0)/rcout)%groups;
|
||||
int c = get_global_id(0) % rcout;
|
||||
|
||||
int Y = get_global_id(1); // range 0-oy
|
||||
int X = get_global_id(2); // range 0-ox
|
||||
int IY = Y*ys;
|
||||
@@ -470,24 +473,20 @@ class Conv2D(Function):
|
||||
// input = (bs, groups, cin, iy, ix)
|
||||
// weight = (groups, rcout, cin, H, W)
|
||||
// output = (bs, groups, rcout, oy, ox)
|
||||
for (int g = 0; g < groups; g++) {
|
||||
for (int c = 0; c < rcout; c++) {
|
||||
float acc = 0.0;
|
||||
for (int ci = 0; ci < cin; ci++) {
|
||||
for (int y = IY; y < IY+H; y++) {
|
||||
for (int x = IX; x < IX+W; x++) {
|
||||
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
|
||||
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
|
||||
}
|
||||
}
|
||||
float acc = 0.0;
|
||||
for (int ci = 0; ci < cin; ci++) {
|
||||
for (int y = IY; y < IY+H; y++) {
|
||||
for (int x = IX; x < IX+W; x++) {
|
||||
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
|
||||
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
|
||||
}
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}
|
||||
}
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}
|
||||
""")
|
||||
|
||||
prg.conv(ctx.cl_queue, [bs, oy, ox], None,
|
||||
prg.conv(ctx.cl_queue, [bs*groups*rcout, oy, ox], None,
|
||||
x, w, ret,
|
||||
np.int32(H), np.int32(W),
|
||||
np.int32(groups), np.int32(rcout), np.int32(cin),
|
||||
|
||||
Reference in New Issue
Block a user