mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix gpu sub
This commit is contained in:
@@ -36,7 +36,7 @@ def unary_op(ctx, code, x):
|
||||
__global const float *a_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = min(a_g[gid], (float)0.);
|
||||
"""+code+"""
|
||||
}
|
||||
""")
|
||||
prg.relu(ctx.cl_queue, [np.prod(ret.shape)], None, x, ret)
|
||||
@@ -59,8 +59,8 @@ class Sub(Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# WRONG
|
||||
return grad_output, grad_output
|
||||
not_grad_output = unary_op(ctx, 'res_g[gid] = -a_g[gid];', grad_output)
|
||||
return grad_output, not_grad_output
|
||||
register('sub', Sub, gpu=True)
|
||||
|
||||
class Mul(Function):
|
||||
|
||||
Reference in New Issue
Block a user