in_place_op

This commit is contained in:
George Hotz
2020-11-02 07:26:13 -08:00
parent 1d793b8571
commit 82fc842b40

View File

@@ -11,20 +11,23 @@ def buffer_new(ctx, shape):
def buffer_like(ctx, x):
    """Allocate a new buffer whose shape matches that of ``x``.

    Delegates to ``buffer_new`` with ``x.shape``; only the shape of ``x``
    is read here.
    """
    shape = x.shape
    return buffer_new(ctx, shape)
def in_place_op(ctx, code, x, y):
    """Run an elementwise OpenCL kernel over ``x`` and ``y``.

    ``code`` is spliced in as the kernel body. Inside it, ``a_g`` is bound to
    ``x``, ``b_g`` to ``y``, ``res_g`` to the output buffer and ``gid`` to the
    global work-item id.

    Despite the name, the inputs are not overwritten: a fresh output buffer
    shaped like ``x`` is allocated and returned. One work item is launched per
    element of that output buffer.

    NOTE: the program is compiled on every call, and the kernel entry point is
    always called ``add`` regardless of what ``code`` computes.
    """
    out = buffer_like(ctx, x)
    kernel_source = """
__kernel void add(
__global const float *a_g, __global const float *b_g, __global float *res_g)
{
int gid = get_global_id(0);
"""+code+"""
}
"""
    program = cl.Program(ctx.cl_ctx, kernel_source).build()
    global_size = [np.prod(out.shape)]
    program.add(ctx.cl_queue, global_size, None, x, y, out)
    return out
class Add(Function):
@staticmethod
def forward(ctx, x, y):
    """Return the elementwise sum ``x + y`` as a new buffer shaped like ``x``.

    The kernel is built and launched by the shared ``in_place_op`` helper
    rather than an inline copy of the same program, which previously left the
    helper call unreachable after an earlier ``return``.
    """
    return in_place_op(ctx, 'res_g[gid] = a_g[gid] + b_g[gid];', x, y)
@staticmethod
def backward(ctx, grad_output):
@@ -34,63 +37,34 @@ register('add', Add, gpu=True)
class Sub(Function):
    """Elementwise subtraction ``x - y`` for the GPU backend."""

    @staticmethod
    def forward(ctx, x, y):
        """Return ``x - y`` elementwise as a new buffer shaped like ``x``."""
        return in_place_op(ctx, 'res_g[gid] = a_g[gid] - b_g[gid];', x, y)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``x`` and ``y``.

        d(x - y)/dx = 1, so the first gradient is ``grad_output`` unchanged.
        d(x - y)/dy = -1, so the second gradient is ``-grad_output``.
        """
        # in_place_op always takes two inputs; the kernel below only reads
        # a_g, so grad_output is passed twice.
        # NOTE(review): assumes grad_output exposes .shape like the buffers
        # returned by buffer_like -- confirm against buffer_new.
        neg_grad = in_place_op(ctx, 'res_g[gid] = -a_g[gid];',
                               grad_output, grad_output)
        return grad_output, neg_grad
register('sub', Sub, gpu=True)
class Mul(Function):
    """Elementwise multiplication ``x * y`` for the GPU backend."""

    @staticmethod
    def forward(ctx, x, y):
        """Return ``x * y`` as a new buffer shaped like ``x``.

        Supports two cases: ``y`` has shape ``(1,)`` (every element of ``x``
        is multiplied by ``y[0]``), or ``x`` and ``y`` have identical shapes.

        Raises:
            Exception: if the shapes match neither case.
        """
        # Saved before the shape check, matching the original ordering.
        ctx.save_for_backward(x, y)
        # HACK: single-element y is special-cased instead of real broadcasting.
        if y.shape == (1,):
            return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[0];', x, y)
        elif x.shape == y.shape:
            return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, y)
        else:
            raise Exception("mismatched shapes %r %r" % (x.shape, y.shape))

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients ``(y * grad_output, x * grad_output)``.

        Only the two tensors stored by ``forward`` are unpacked; no compiled
        program is saved any more.
        """
        # NOTE(review): both products use the same-shape kernel, and the first
        # output is shaped like y. The y.shape == (1,) case from forward looks
        # unhandled here -- confirm before relying on it.
        x, y = ctx.saved_tensors
        grad_x = in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', y, grad_output)
        grad_y = in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, grad_output)
        return grad_x, grad_y
register('mul', Mul, gpu=True)
class Sum(Function):