diff --git a/test/test_ops.py b/test/test_ops.py
index bcf18dd9dc..41f3b3f3fe 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -82,7 +82,6 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
-  @cpu_only
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
     helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 2ec463f3b5..92dd22155e 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -228,6 +228,24 @@ class Sum(Function):
     return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
 register('sum', Sum, device=Device.GPU)
 
+class Max(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
+    ctx.save_for_backward(input, axis, ret)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
+    return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
+register('max', Max, device=Device.GPU)
+
 class Dot(Function):
   @staticmethod
   def forward(ctx, input, weight):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0d2c611f18..c8c8c43c97 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -224,15 +224,15 @@ class Tensor:
 
   def softmax(self):
     ns = list(self.shape)[:-1]+[1]
-    #e = (self - self.max(axis=len(self.shape)-1).reshape(shape=ns)).exp()
-    e = self.exp()
+    m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
+    e = (self - m).exp()
     ss = e.sum(axis=len(self.shape)-1).reshape(shape=ns)
     return e.div(ss)
 
   def logsoftmax(self):
     ns = list(self.shape)[:-1]+[1]
-    # TODO: logsumexp stability with max
-    ss = self.exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
+    m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
+    ss = m + (self-m).exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
     return self - ss
 
   def dropout(self, p=0.5):
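
Note (not part of the patch): a minimal plain-numpy sketch of the max-subtraction trick the softmax/logsoftmax change above relies on; the function names here are illustrative only. Because softmax(x) == softmax(x - c) for any constant c, shifting by the row max keeps every exponent argument at or below zero, so exp() cannot overflow.

import numpy as np

def naive_softmax(x):
  e = np.exp(x)                      # overflows to inf for large inputs
  return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
  m = x.max(axis=-1, keepdims=True)  # same shift Tensor.softmax now applies
  e = np.exp(x - m)
  return e / e.sum(axis=-1, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(naive_softmax(x))   # [[nan nan nan]] -- exp(1000) overflows
print(stable_softmax(x))  # [[0.09003057 0.24472847 0.66524096]]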