fix gpu sub

This commit is contained in:
George Hotz
2020-11-02 08:18:58 -08:00
parent 231c1134bd
commit 8766346187

View File

@@ -36,7 +36,7 @@ def unary_op(ctx, code, x):
__global const float *a_g, __global float *res_g)
{
int gid = get_global_id(0);
res_g[gid] = min(a_g[gid], (float)0.);
"""+code+"""
}
""")
prg.relu(ctx.cl_queue, [np.prod(ret.shape)], None, x, ret)
@@ -59,8 +59,8 @@ class Sub(Function):
@staticmethod
def backward(ctx, grad_output):
# WRONG
return grad_output, grad_output
not_grad_output = unary_op(ctx, 'res_g[gid] = -a_g[gid];', grad_output)
return grad_output, not_grad_output
register('sub', Sub, gpu=True)
class Mul(Function):