mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
in_place_op
This commit is contained in:
@@ -11,20 +11,23 @@ def buffer_new(ctx, shape):
|
||||
def buffer_like(ctx, x):
|
||||
return buffer_new(ctx, x.shape)
|
||||
|
||||
def in_place_op(ctx, code, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void add(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
"""+code+"""
|
||||
}
|
||||
""").build()
|
||||
prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return ret
|
||||
|
||||
class Add(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void add(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] + b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return ret
|
||||
return in_place_op(ctx, 'res_g[gid] = a_g[gid] + b_g[gid];', x, y)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
@@ -34,63 +37,34 @@ register('add', Add, gpu=True)
|
||||
class Sub(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void sub(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] - b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.sub(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return ret
|
||||
return in_place_op(ctx, 'res_g[gid] = a_g[gid] - b_g[gid];', x, y)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# WRONG
|
||||
return grad_output, grad_output
|
||||
register('sub', Sub, gpu=True)
|
||||
|
||||
class Mul(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
ctx.save_for_backward(x, y)
|
||||
|
||||
# HACK
|
||||
if y.shape == (1,):
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void mul(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] * b_g[0];
|
||||
}
|
||||
""").build()
|
||||
prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[0];', x, y)
|
||||
elif x.shape == y.shape:
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void mul(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] * b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, y)
|
||||
else:
|
||||
raise Exception("mismatched shapes %r %r" % (x.shape, y.shape))
|
||||
|
||||
ctx.save_for_backward(x, y, prg)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x,y,prg = ctx.saved_tensors
|
||||
gx = buffer_like(ctx, x)
|
||||
gy = buffer_like(ctx, y)
|
||||
prg.mul(ctx.cl_queue, [gx.size//4], None, y, grad_output, gx)
|
||||
prg.mul(ctx.cl_queue, [gy.size//4], None, x, grad_output, gy)
|
||||
return gx, gy
|
||||
x,y = ctx.saved_tensors
|
||||
return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', y, grad_output),\
|
||||
in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, grad_output)
|
||||
register('mul', Mul, gpu=True)
|
||||
|
||||
class Sum(Function):
|
||||
|
||||
Reference in New Issue
Block a user