From 82fc842b40fe03742e4e56683788bf6fb2072956 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 2 Nov 2020 07:26:13 -0800 Subject: [PATCH] in_place_op --- tinygrad/opsgpu.py | 70 +++++++++++++++------------------------------- 1 file changed, 22 insertions(+), 48 deletions(-) diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index 7f624f0ebb..37553fb68e 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -11,20 +11,23 @@ def buffer_new(ctx, shape): def buffer_like(ctx, x): return buffer_new(ctx, x.shape) +def in_place_op(ctx, code, x, y): + ret = buffer_like(ctx, x) + prg = cl.Program(ctx.cl_ctx, """ + __kernel void add( + __global const float *a_g, __global const float *b_g, __global float *res_g) + { + int gid = get_global_id(0); + """+code+""" + } + """).build() + prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret) + return ret + class Add(Function): @staticmethod def forward(ctx, x, y): - ret = buffer_like(ctx, x) - prg = cl.Program(ctx.cl_ctx, """ - __kernel void add( - __global const float *a_g, __global const float *b_g, __global float *res_g) - { - int gid = get_global_id(0); - res_g[gid] = a_g[gid] + b_g[gid]; - } - """).build() - prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret) - return ret + return in_place_op(ctx, 'res_g[gid] = a_g[gid] + b_g[gid];', x, y) @staticmethod def backward(ctx, grad_output): @@ -34,63 +37,34 @@ register('add', Add, gpu=True) class Sub(Function): @staticmethod def forward(ctx, x, y): - ret = buffer_like(ctx, x) - prg = cl.Program(ctx.cl_ctx, """ - __kernel void sub( - __global const float *a_g, __global const float *b_g, __global float *res_g) - { - int gid = get_global_id(0); - res_g[gid] = a_g[gid] - b_g[gid]; - } - """).build() - prg.sub(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret) - return ret + return in_place_op(ctx, 'res_g[gid] = a_g[gid] - b_g[gid];', x, y) @staticmethod def backward(ctx, grad_output): + # WRONG return grad_output, grad_output register('sub', Sub, gpu=True) class Mul(Function): @staticmethod def forward(ctx, x, y): - ret = buffer_like(ctx, x) + ctx.save_for_backward(x, y) # HACK if y.shape == (1,): - prg = cl.Program(ctx.cl_ctx, """ - __kernel void mul( - __global const float *a_g, __global const float *b_g, __global float *res_g) - { - int gid = get_global_id(0); - res_g[gid] = a_g[gid] * b_g[0]; - } - """).build() - prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret) + return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[0];', x, y) elif x.shape == y.shape: - prg = cl.Program(ctx.cl_ctx, """ - __kernel void mul( - __global const float *a_g, __global const float *b_g, __global float *res_g) - { - int gid = get_global_id(0); - res_g[gid] = a_g[gid] * b_g[gid]; - } - """).build() - prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret) + return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, y) else: raise Exception("mismatched shapes %r %r" % (x.shape, y.shape)) - ctx.save_for_backward(x, y, prg) return ret @staticmethod def backward(ctx, grad_output): - x,y,prg = ctx.saved_tensors - gx = buffer_like(ctx, x) - gy = buffer_like(ctx, y) - prg.mul(ctx.cl_queue, [gx.size//4], None, y, grad_output, gx) - prg.mul(ctx.cl_queue, [gy.size//4], None, x, grad_output, gy) - return gx, gy + x,y = ctx.saved_tensors + return in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', y, grad_output),\ + in_place_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, grad_output) register('mul', Mul, gpu=True) class Sum(Function):