mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanup LogSoftmax
This commit is contained in:
@@ -39,7 +39,7 @@ def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, decls=''):
|
||||
"""+decls+""";
|
||||
for (uint j=0; j<ksz.y; ++j) {
|
||||
for (uint i=0; i<ksz.x; ++i) {
|
||||
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
|
||||
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
|
||||
if (gid.x*stride.x+i < isize.x && gid.y*stride.y+j < isize.y) {
|
||||
"""+iter_op+""";
|
||||
}
|
||||
@@ -405,25 +405,9 @@ class LogSoftmax(Function):
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
output, = ctx.saved_tensors
|
||||
|
||||
grad_input = buffer_like(ctx, grad_output)
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
__kernel void lsmsub2(__global const float *grad_output, __global const float *output, int sz,
|
||||
__global float *grad_input) {
|
||||
int gidsz = get_global_id(0)*sz;
|
||||
int gid2 = get_global_id(1);
|
||||
|
||||
float acc = 0.0;
|
||||
for (int x = 0; x < sz; x++) {
|
||||
acc += grad_output[gidsz + x];
|
||||
}
|
||||
|
||||
grad_input[gidsz + gid2] = grad_output[gidsz + gid2] - exp(output[gidsz + gid2]) * acc;
|
||||
}""")
|
||||
prg.lsmsub2(ctx.cl_queue, [grad_output.shape[0], grad_output.shape[1]], None,
|
||||
grad_output, output, i32(grad_output.shape[1]), grad_input)
|
||||
|
||||
return grad_input
|
||||
lsum = reduce_op(ctx, "out += a", "out", grad_output, (grad_output.shape[0],1))
|
||||
texp = binary_op(ctx, "exp(a) * b", output, lsum)
|
||||
return binary_op(ctx, "a - b", grad_output, texp)
|
||||
register('logsoftmax', LogSoftmax, gpu=True)
|
||||
|
||||
# ************* conv ops *************
|
||||
@@ -506,14 +490,15 @@ class Conv2D(Function):
|
||||
int y = get_global_id(1); // range 0-H
|
||||
int x = get_global_id(2); // range 0-W
|
||||
|
||||
// tensx = (bs, groups*cin, iy, ix)
|
||||
// tensx = (bs, groups*cin, iy, ix)
|
||||
// tensw = (groups*rcout, cin, H, W)
|
||||
// ggg = (bs, groups*rout, oy, ox)
|
||||
float acc = 0.0;
|
||||
for (int Y = 0; Y < oy; Y++) {
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int B = 0; B < bs; B++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
|
||||
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -531,7 +516,9 @@ class Conv2D(Function):
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int y = 0; y < H; y++) {
|
||||
for (int x = 0; x < W; x++) {
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x]+= ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += \
|
||||
ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user