Add GPU max (thanks to marcelbischoff)

This commit is contained in:
George Hotz
2020-12-29 16:44:14 -05:00
parent 4bbad11afe
commit 27208d729b
3 changed files with 22 additions and 5 deletions

View File

@@ -82,7 +82,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
@cpu_only
def test_max(self):
  # Whole-tensor max reduction, with and without a trailing elementwise op
  # (the .mul(0.5) case exercises the gradient path through Max).
  cases = [
    (lambda x: x.max(), Tensor.max),
    (lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5)),
  ]
  for torch_fn, tinygrad_fn in cases:
    helper_test_op([(45,3)], torch_fn, tinygrad_fn, device=self.device)

View File

@@ -228,6 +228,24 @@ class Sum(Function):
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
register('sum', Sum, device=Device.GPU)  # expose Sum as the GPU implementation of Tensor.sum
class Max(Function):
  """GPU max reduction.

  forward takes the max over `axis` (over all elements when axis is None);
  backward routes the upstream gradient to the positions that attained the
  max. NOTE(review): on ties, every tying position receives the full
  gradient (the mask below is 1.0 at each equal element).
  """
  @staticmethod
  def forward(ctx, input, axis=None):
    # Normalize a single int axis to a list so reduce_op sees one form.
    if type(axis) == int:
      axis = [axis]
    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
    ctx.save_for_backward(input, axis, ret)
    if axis is not None:
      # Drop the reduced dimensions from the result's reported shape.
      ret.shape = tuple(dim for i, dim in enumerate(input.shape) if i not in axis)
    return ret
  @staticmethod
  def backward(ctx, grad_output):
    input, axis, ret = ctx.saved_tensors
    # Broadcastable shape: reduced axes collapse to 1, others keep their size.
    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
    # 1.0 where input equals the reduced max, 0.0 elsewhere ...
    max_mask = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
    # ... then scale that mask by the incoming gradient.
    return binary_op(ctx, 'a*b', max_mask, GPUBuffer(shape, grad_output))
register('max', Max, device=Device.GPU)
class Dot(Function):
@staticmethod
def forward(ctx, input, weight):

View File

@@ -224,15 +224,15 @@ class Tensor:
def softmax(self):
  """Numerically stable softmax over the last axis.

  Subtracts the per-row max before exponentiating so exp() cannot
  overflow; the shift cancels out in the normalization, leaving the
  result unchanged.
  """
  ns = list(self.shape)[:-1]+[1]  # shape that broadcasts a row-wise reduction back over the row
  m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
  e = (self - m).exp()
  ss = e.sum(axis=len(self.shape)-1).reshape(shape=ns)
  return e.div(ss)
def logsoftmax(self):
  """Numerically stable log-softmax over the last axis.

  Uses the log-sum-exp identity log(sum(exp(x))) = m + log(sum(exp(x - m)))
  with m = max(x), which keeps exp() within range.
  """
  ns = list(self.shape)[:-1]+[1]  # shape that broadcasts a row-wise reduction back over the row
  m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
  ss = m + (self-m).exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
  return self - ss
def dropout(self, p=0.5):