* init GPU supsample retbuf to 0

* reduce GPU kernel source lines

ref: #94
Author: Ryan Neph
Date:   2020-11-10 01:20:04 -08:00
Committed by: GitHub
Parent: 55012d21bb
Commit: 56f71ae8e5

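The first bullet above refers to the buffer_new → buffer_zeros switch in subsample_op below: these kernels only write output elements that pass a bounds guard, so a freshly allocated return buffer can leak stale device memory into the result. A minimal sketch of a zero-filled allocation in pyopencl (hypothetical helper; the repo's actual buffer_zeros may be implemented differently):

import numpy as np
import pyopencl as cl

def buffer_zeros(cl_ctx, shape):
    # Copy a host array of zeros into a fresh device buffer so kernels that
    # skip some output elements never read back uninitialized memory.
    host = np.zeros(shape, dtype=np.float32)
    return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=host)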

@@ -29,10 +29,8 @@ def uint2(x, y):
def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
prg = """
-__kernel void subsample(
-    __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size,
-    uint2 stride, int nelem
-) {
+__kernel void subsample(__global float *output, __global const float *input, uint2 osize, uint2 isize,
+    uint2 kernel_size, uint2 stride, int nelem) {
int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
float group_res = """+str(init_val)+""";
@@ -45,15 +43,14 @@ def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
}
}
output[oid] = """+result_op+""";
-}
-"""
+}"""
return clbuild(cl_ctx, prg)
def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=0):
py, px = stride
N, C, Yin, Xin = input.shape
Yout, Xout = (Yin-kernel_size[0])//py+1, (Xin-kernel_size[1])//px+1
-ret = buffer_new(ctx, (N, C, Yout, Xout))
+ret = buffer_zeros(ctx, (N, C, Yout, Xout))
prg = cl_subsample_krnl_build(ctx.cl_ctx, iter_op, result_op, init_val=init_val)
prg.subsample(ctx.cl_queue, (N*C, Yout, Xout), None,
ret, input, uint2(Xout, Yout), uint2(Xin, Yin),
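Note: the Yout/Xout expressions in subsample_op are the standard no-padding pooling output size, (in - kernel) // stride + 1. A quick worked check with illustrative values:

# 4x4 input, 2x2 kernel, stride 2 -> 2x2 output
Yin, Xin = 4, 4
(ky, kx), (py, px) = (2, 2), (2, 2)
Yout, Xout = (Yin - ky) // py + 1, (Xin - kx) // px + 1
assert (Yout, Xout) == (2, 2)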
@@ -63,17 +60,15 @@ def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=0
def cl_supsample_krnl_build(cl_ctx, result_op):
prg = """
-__kernel void supsample(
-    __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size, int nelem
-) {
+__kernel void supsample(__global float *output, __global const float *input, uint2 osize, uint2 isize,
+    uint2 kernel_size, int nelem) {
int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
int iid = (gid.x/kernel_size.x) + isize.x*((gid.y/kernel_size.y) + isize.y*gid.z);
if (gid.x/kernel_size.x < isize.x && gid.y/kernel_size.y < isize.y) {
output[oid] = """+result_op+""";
}
-}
-"""
+}"""
return clbuild(cl_ctx, prg)
def supersample_op(ctx, input, out_shape, kernel_size, result_op):
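Note: the supsample kernel's iid line maps each output pixel back to its source by integer division, which is exactly nearest-neighbor upsampling. A single-channel NumPy equivalent (illustrative values only):

import numpy as np
inp = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
ky, kx = 2, 2
ys, xs = np.indices((inp.shape[0]*ky, inp.shape[1]*kx))
out = inp[ys // ky, xs // kx]  # each input pixel repeated ky*kx times
assert out[1, 1] == 1. and out[3, 3] == 4.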
@@ -106,38 +101,31 @@ def binary_op(ctx, code, x, y):
raise Exception("binary op shape mismatch: %r != %r" % (x.shape, y.shape))
ret = buffer_like(ctx, x if np.prod(x.shape) >= np.prod(y.shape) else y)
prg = clbuild(ctx.cl_ctx, """
-__kernel void binop(
-    __global const float *a_g, __global const float *b_g, __global float *res_g, int xdiv, int ydiv)
-{
+__kernel void binop(__global const float *a_g, __global const float *b_g, __global float *res_g,
+    int xdiv, int ydiv) {
int gid = get_global_id(0);
float a = a_g[gid/xdiv];
float b = b_g[gid/ydiv];
res_g[gid] = """+code+""";
-}
-""")
+}""")
prg.binop(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret, np.int32(xdiv), np.int32(ydiv))
return ret
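Note: xdiv and ydiv give binop a cheap broadcast: every output element gid reads a_g[gid/xdiv] and b_g[gid/ydiv], so the smaller operand gets repeated across the larger one. A host-side sketch of that indexing (the divisor values here are assumed; how the callers compute them is outside this diff):

import numpy as np
a = np.arange(6, dtype=np.float32)
b = np.array([10.], dtype=np.float32)
xdiv, ydiv = 1, a.size  # gid // ydiv == 0, so every element pairs with b[0]
out = np.array([a[g // xdiv] + b[g // ydiv] for g in range(a.size)])
assert np.allclose(out, a + 10.)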
def unary_op(ctx, code, x):
ret = buffer_like(ctx, x)
prg = clbuild(ctx.cl_ctx, """
-__kernel void unop(
-    __global const float *a_g, __global float *res_g)
-{
+__kernel void unop(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0);
float a = a_g[gid];
res_g[gid] = """+code+""";
-}
-""")
+}""")
prg.unop(ctx.cl_queue, [np.prod(ret.shape)], None, x, ret)
return ret
def reduce_op(ctx, code, code2, input, osize):
ret = buffer_new(ctx, osize)
prg = clbuild(ctx.cl_ctx, """
-__kernel void reduce(
-    __global const float *a_g, int sz, __global float *res_g)
-{
+__kernel void reduce(__global const float *a_g, int sz, __global float *res_g) {
int gid = get_global_id(0);
float out = 0.0;
for (int x = 0; x < sz; x++) {
@@ -145,8 +133,7 @@ def reduce_op(ctx, code, code2, input, osize):
"""+code+""";
}
res_g[gid] = """+code2+""";
-}
-""")
+}""")
prg.reduce(ctx.cl_queue, osize, None, input, np.int32(np.prod(input.shape) // np.prod(osize)), ret)
return ret
@@ -214,13 +201,10 @@ class Sum(Function):
ret = buffer_like(ctx, input)
prg = clbuild(ctx.cl_ctx, """
-__kernel void fill(
-    __global const float *a_g, __global float *res_g)
-{
+__kernel void fill(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0);
res_g[gid] = a_g[0];
-}
-""")
+}""")
prg.fill(ctx.cl_queue, [np.prod(ret.shape)], None, grad_output, ret)
return ret
@@ -243,8 +227,7 @@ class Dot(Function):
__global float *res,
int is0, int is1, int msize,
int ws0, int ws1, int osize
-)
-{
+) {
int X = get_global_id(0); // isize
int Y = get_global_id(1); // osize
@@ -254,8 +237,7 @@ class Dot(Function):
}
res[X * osize + Y] = ret;
-}
-""")
+}""")
ctx.save_for_backward(input, weight, prg)
# (isize,msize) x (msize,osize) = (isize,osize)
prg.matmul(ctx.cl_queue, [isize, osize], None,
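Note: the is0/is1 and ws0/ws1 strides let this single kernel compute plain and transposed products alike, since it indexes input[X*is0 + m*is1] and weight[m*ws0 + Y*ws1]; the backward pass can reuse it by swapping strides. A NumPy sketch of that indexing over flat buffers (the accumulation loop over m is assumed, as the hunk only shows its surroundings):

import numpy as np

def matmul_strided(inp, w, isize, msize, osize, is0, is1, ws0, ws1):
    res = np.zeros((isize, osize), dtype=np.float32)
    for X in range(isize):       # one work-item per (X, Y) in the kernel
        for Y in range(osize):
            res[X, Y] = sum(inp[X*is0 + m*is1] * w[m*ws0 + Y*ws1] for m in range(msize))
    return res

A = np.arange(6, dtype=np.float32).reshape(2, 3)
B = np.arange(12, dtype=np.float32).reshape(3, 4)
assert np.allclose(matmul_strided(A.ravel(), B.ravel(), 2, 3, 4, 3, 1, 4, 1), A @ B)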
@@ -298,11 +280,8 @@ class Pad2D(Function):
ret = buffer_zeros(ctx, (bs, cin, oy, ox))
prg = clbuild(ctx.cl_ctx, """
-__kernel void pad2d(
-    __global const float *input, __global float *output,
-    int py, int px, int oy, int ox, int iy, int ix
-)
-{
+__kernel void pad2d(__global const float *input, __global float *output,
+    int py, int px, int oy, int ox, int iy, int ix) {
int BC = get_global_id(0);
int Y = get_global_id(1);
int X = get_global_id(2);
@@ -311,8 +290,7 @@ class Pad2D(Function):
int optr = BC*oy*ox + (Y+py)*ox + px + X;
output[optr] = input[iptr];
-}
-""")
+}""")
ctx.save_for_backward(padding)
prg.pad2d(ctx.cl_queue, [bs*cin, iy, ix], None,
x, ret,
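Note: the forward pad2d indexing just shifts every input pixel by the padding offsets, output[BC][Y+py][X+px] = input[BC][Y][X], with the zeroed return buffer supplying the border. One-channel NumPy equivalent (symmetric padding assumed for brevity):

import numpy as np
iy, ix, py, px = 2, 3, 1, 2
inp = np.arange(iy*ix, dtype=np.float32).reshape(iy, ix)
out = np.zeros((iy + 2*py, ix + 2*px), dtype=np.float32)  # plays the role of buffer_zeros
out[py:py+iy, px:px+ix] = inp  # matches optr = (Y+py)*ox + (X+px)
assert out[py, px] == inp[0, 0] and out[0, 0] == 0.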
@@ -328,11 +306,8 @@ class Pad2D(Function):
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
ret = buffer_new(ctx, (bs, cin, oy, ox))
prg = clbuild(ctx.cl_ctx, """
-__kernel void pad2d(
-    __global const float *input, __global float *output,
-    int cin, int py, int px, int oy, int ox, int iy, int ix
-)
-{
+__kernel void pad2d(__global const float *input, __global float *output,
+    int cin, int py, int px, int oy, int ox, int iy, int ix) {
int B = get_global_id(0);
int C = get_global_id(1);
int Y = get_global_id(2);
@@ -343,8 +318,7 @@ class Pad2D(Function):
for (int x = 0; x < ox; x++) {
output[optr+x] = input[iptr+x];
}
-}
-""")
+}""")
prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
grad_output, ret,
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
@@ -449,9 +423,8 @@ class LogSoftmax(Function):
grad_input = buffer_like(ctx, grad_output)
prg = clbuild(ctx.cl_ctx, """
-__kernel void lsmsub2(
-    __global const float *grad_output, __global const float *output, int sz, __global float *grad_input)
-{
+__kernel void lsmsub2(__global const float *grad_output, __global const float *output, int sz,
+    __global float *grad_input) {
int gid = get_global_id(0);
int gidsz = gid*sz;
int gid2 = get_global_id(1);
@@ -463,8 +436,7 @@ class LogSoftmax(Function):
}
grad_input[gidsz + gid2] = grad_output[gidsz + gid2] - exp(output[gidsz + gid2]) * acc;
-}
-""")
+}""")
prg.lsmsub2(ctx.cl_queue, [grad_output.shape[0], grad_output.shape[1]], None,
grad_output, output, np.int32(grad_output.shape[1]), grad_input)
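Note: lsmsub2 is the log-softmax backward rule: output holds log-softmax values, so exp(output) recovers the softmax, and grad_input_i = grad_output_i - softmax_i * sum_j grad_output_j, where the sum over the class axis is the kernel's acc loop. A NumPy restatement with made-up values:

import numpy as np
x = np.array([[1., 2., 3.]])
output = x - np.log(np.exp(x).sum(axis=1, keepdims=True))  # log-softmax
grad_output = np.array([[0.1, -0.2, 0.3]])
acc = grad_output.sum(axis=1, keepdims=True)               # the kernel's acc loop
grad_input = grad_output - np.exp(output) * acc            # same expression as the kernel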
@@ -515,8 +487,7 @@ class Conv2D(Function):
}
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
-}
-""")
+}""")
prg.conv(ctx.cl_queue, [bs*groups*rcout, oy, ox], None,
x, w, ret,