mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00

unbroadcast GPU template
@@ -80,7 +80,7 @@ class TestOps(unittest.TestCase):
     for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)),
                    ((4,1), (4,5)), ((1,4), (5,4))]:
       with self.subTest(op=torch_op.__name__, shapes=shapes):
-        helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=True)
+        helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=self.gpu)

   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
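All four shape pairs in the loop above broadcast cleanly, so the forward pass is well defined on both backends; forward_only=self.gpu skips the backward check only on the GPU, where unbroadcast is still a stub (see the hunks below). A quick NumPy sketch of the broadcast output shapes, assuming NumPy >= 1.20 for np.broadcast_shapes:

import numpy as np

# Shape pairs from the test above; broadcasting each pair gives the forward output shape.
pairs = [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)),
         ((4,1), (4,5)), ((1,4), (5,4))]
for a, b in pairs:
  print(a, b, '->', np.broadcast_shapes(a, b))
# (1, 32, 32, 32) (1, 32, 1, 1) -> (1, 32, 32, 32)
# (5, 13, 24, 16, 2) (1, 13, 24, 1, 1) -> (5, 13, 24, 16, 2)
# (4, 1) (4, 5) -> (4, 5)
# (1, 4) (5, 4) -> (5, 4)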
@@ -12,7 +12,7 @@ def unbroadcast(out, in_sh):
 class Add(Function):
   @staticmethod
   def forward(ctx, x, y):
-    ctx.save_for_backward(x.shape,y.shape)
+    ctx.save_for_backward(x.shape, y.shape)
     return x+y

   @staticmethod
@@ -24,7 +24,7 @@ register('add', Add)
 class Sub(Function):
   @staticmethod
   def forward(ctx, x, y):
-    ctx.save_for_backward(x.shape,y.shape)
+    ctx.save_for_backward(x.shape, y.shape)
     return x-y

   @staticmethod
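The two hunk headers above sit just below ops.py's existing unbroadcast(out, in_sh) helper, which the GPU template in the next hunk is supposed to eventually match (per its TODO). The CPU-side idea is to sum the gradient over the axes that were broadcast and reshape it back to the input's shape. A minimal NumPy sketch of that behaviour, assuming equal-rank shapes as in the tests above (an illustration, not the actual ops.py code):

import numpy as np

def unbroadcast_ref(out, in_sh):
  # Sum over the axes where the input had size 1 but the (broadcast) gradient is larger,
  # then reshape back to the original input shape. Assumes out.ndim == len(in_sh).
  if out.shape == in_sh:
    return out
  sum_axes = tuple(i for i, s in enumerate(in_sh) if s == 1 and out.shape[i] > 1)
  return out.sum(axis=sum_axes, keepdims=True).reshape(in_sh)

g = np.ones((1, 32, 32, 32))                    # gradient in the broadcast shape
print(unbroadcast_ref(g, (1, 32, 1, 1)).shape)  # (1, 32, 1, 1)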
@@ -144,26 +144,37 @@ def reduce_op(ctx, code, code2, input, osize):
   prg.reduce(ctx.cl_queue, osize, None, input, i32(np.prod(input.shape) // np.prod(osize)), ret)
   return ret

+def unbroadcast(ctx, out, in_sh):
+  # TODO: write this to match unbroadcast from ops.py
+  assert out.shape == in_sh
+  return out
+
 # ***** now for the ops themselves *****

 class Add(Function):
   @staticmethod
   def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a+b', x, y)

   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, grad_output
+    grad_x, grad_y = grad_output, grad_output
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 register('add', Add, gpu=True)

 class Sub(Function):
   @staticmethod
   def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a-b', x, y)

   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, unary_op(ctx, '-a', grad_output)
+    grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 register('sub', Sub, gpu=True)

 class Mul(Function):
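Why Add and Sub now save the input shapes: when one operand was broadcast in the forward pass, grad_output comes back in the broadcast shape and has to be reduced to that operand's shape before it is returned. A self-contained NumPy illustration using one of the shape pairs from the test (the sum is what a finished unbroadcast would perform; the GPU template above still just asserts the shapes already match):

import numpy as np

x = np.random.randn(1, 4)
y = np.random.randn(5, 4)
grad_output = np.ones_like(x + y)                 # upstream gradient has the broadcast shape (5, 4)
grad_x = grad_output.sum(axis=0, keepdims=True)   # reduce back to x's shape: each x entry fed 5 outputs
grad_y = grad_output                              # y's shape already matches, passes through unchanged
print(grad_x.shape, grad_y.shape)                 # (1, 4) (5, 4)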
@@ -175,8 +186,9 @@ class Mul(Function):
   @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
-    return binary_op(ctx, 'a*b', y, grad_output),\
-           binary_op(ctx, 'a*b', x, grad_output)
+    grad_x = binary_op(ctx, 'a*b', y, grad_output)
+    grad_y = binary_op(ctx, 'a*b', x, grad_output)
+    return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 register('mul', Mul, gpu=True)

 class Pow(Function):
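Mul composes the same reduction with the product rule: each gradient is computed elementwise in the broadcast shape (y * grad_output and x * grad_output) and only then reduced to its own input's shape. A NumPy sketch with illustrative shapes:

import numpy as np

x = np.random.randn(4, 1)
y = np.random.randn(4, 5)
grad_output = np.random.randn(4, 5)                     # same shape as the broadcast product x * y
grad_x = (y * grad_output).sum(axis=1, keepdims=True)   # product rule, then reduce back to (4, 1)
grad_y = x * grad_output                                # broadcasts to (4, 5), already y's shape
print(grad_x.shape, grad_y.shape)                       # (4, 1) (4, 5)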
@@ -188,11 +200,11 @@ class Pow(Function):
   @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
-    gradx = binary_op(ctx, 'a*b', grad_output,
+    grad_x = binary_op(ctx, 'a*b', grad_output,
                       binary_op(ctx, 'b * (pow((float)a, (float)(b-1.0)))', x, y))
-    grady = binary_op(ctx, 'a*b', grad_output,
+    grad_y = binary_op(ctx, 'a*b', grad_output,
                       binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
-    return gradx, grady
+    return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 register('pow', Pow, gpu=True)

 class Sum(Function):
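The two OpenCL expressions in the Pow backward are the usual power-rule derivatives, d(x^y)/dx = y*x^(y-1) and d(x^y)/dy = x^y*log(x), applied elementwise before the unbroadcast reduction. A quick NumPy finite-difference check of those formulas (scalar values chosen for illustration, x > 0 so the log is defined):

import numpy as np

x, y, eps = 2.0, 3.0, 1e-6
grad_x = y * x ** (y - 1)       # matches the first kernel expression
grad_y = x ** y * np.log(x)     # matches the second kernel expression
fd_x = ((x + eps) ** y - (x - eps) ** y) / (2 * eps)    # central difference in x
fd_y = (x ** (y + eps) - x ** (y - eps)) / (2 * eps)    # central difference in y
print(np.allclose([grad_x, grad_y], [fd_x, fd_y], rtol=1e-4))   # True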