cleanup LogSoftmax

This commit is contained in:
George Hotz
2020-11-15 20:49:57 -08:00
parent d1441de3a6
commit 1207fe4c7d

View File

@@ -39,7 +39,7 @@ def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, decls=''):
"""+decls+""";
for (uint j=0; j<ksz.y; ++j) {
for (uint i=0; i<ksz.x; ++i) {
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
if (gid.x*stride.x+i < isize.x && gid.y*stride.y+j < isize.y) {
"""+iter_op+""";
}
@@ -405,25 +405,9 @@ class LogSoftmax(Function):
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
grad_input = buffer_like(ctx, grad_output)
prg = clbuild(ctx.cl_ctx, """
__kernel void lsmsub2(__global const float *grad_output, __global const float *output, int sz,
__global float *grad_input) {
int gidsz = get_global_id(0)*sz;
int gid2 = get_global_id(1);
float acc = 0.0;
for (int x = 0; x < sz; x++) {
acc += grad_output[gidsz + x];
}
grad_input[gidsz + gid2] = grad_output[gidsz + gid2] - exp(output[gidsz + gid2]) * acc;
}""")
prg.lsmsub2(ctx.cl_queue, [grad_output.shape[0], grad_output.shape[1]], None,
grad_output, output, i32(grad_output.shape[1]), grad_input)
return grad_input
lsum = reduce_op(ctx, "out += a", "out", grad_output, (grad_output.shape[0],1))
texp = binary_op(ctx, "exp(a) * b", output, lsum)
return binary_op(ctx, "a - b", grad_output, texp)
register('logsoftmax', LogSoftmax, gpu=True)
# ************* conv ops *************
@@ -506,14 +490,15 @@ class Conv2D(Function):
int y = get_global_id(1); // range 0-H
int x = get_global_id(2); // range 0-W
// tensx = (bs, groups*cin, iy, ix)
// tensx = (bs, groups*cin, iy, ix)
// tensw = (groups*rcout, cin, H, W)
// ggg = (bs, groups*rout, oy, ox)
float acc = 0.0;
for (int Y = 0; Y < oy; Y++) {
for (int X = 0; X < ox; X++) {
for (int B = 0; B < bs; B++) {
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
}
}
}
@@ -531,7 +516,9 @@ class Conv2D(Function):
for (int X = 0; X < ox; X++) {
for (int y = 0; y < H; y++) {
for (int x = 0; x < W; x++) {
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x]+= ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += \
ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
}
}
}