diff --git a/README.md b/README.md
index a7d05e2bfe..a1e79c9a14 100644
--- a/README.md
+++ b/README.md
@@ -105,12 +105,12 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### Adding an accelerator
 
-You need to support 14 basic ops:
+You need to support 15 basic ops:
 
 ```
 Add, Sub, Mul, Pow          # binary ops
 Relu, Log, Exp              # unary ops
-Sum                         # reduce op
+Sum, Max                    # reduce ops
 Dot                         # matrix multiplication
 Conv2D, MaxPool2D           # 2D ops
 Pad2D, Reshape, Transpose   # moving things around ops
diff --git a/test/test_ops.py b/test/test_ops.py
index 457f487380..db9e370f14 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -83,9 +83,11 @@ class TestOps(unittest.TestCase):
   @cpu_only
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
+    helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
   @cpu_only
   def test_max_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
+    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
   def test_sum_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
   def test_mean_axis(self):
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index d8e4485e18..81ac437265 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -75,10 +75,7 @@ class Max(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
     am = input.argmax(axis=axis)
-    if axis is not None:
-      am = np.expand_dims(am, axis=axis)
-    else:
-      am = np.array([am])
+    am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])
     ctx.save_for_backward(input.shape, am, axis)
     return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)
 
@@ -86,7 +83,7 @@ def backward(ctx, grad_output):
     shape, am, axis = ctx.saved_tensors
     ret = np.zeros(shape, dtype=np.float32)
-    np.put_along_axis(ret, am, 1/np.prod(am.shape), axis=axis)
+    np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
     return ret
 register('max', Max)
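
The ops_cpu.py change is the substance of this diff: the old `Max.backward` scattered the constant `1/np.prod(am.shape)` into the gradient, ignoring `grad_output` entirely, so any downstream scaling of the loss was lost. The fix scatters `grad_output` itself, which is exactly what the new `.mul(0.5)` test cases exercise. A minimal standalone numpy sketch of the difference (independent of tinygrad; the array values are illustrative only):

```python
import numpy as np

x = np.array([[1.0, 5.0, 3.0]], dtype=np.float32)

# forward: same argmax/take_along_axis trick as ops_cpu.py, with axis=1
am = np.expand_dims(x.argmax(axis=1), axis=1)       # indices of the max, shape (1,1)
out = np.take_along_axis(x, am, axis=1).squeeze(1)  # max values, shape (1,)

# backward for loss = 0.5 * max(x): the incoming gradient is 0.5
grad_output = np.full(out.shape, 0.5, dtype=np.float32)

# old behaviour: constant fill, which ignores grad_output entirely
old = np.zeros(x.shape, dtype=np.float32)
np.put_along_axis(old, am, 1/np.prod(am.shape), axis=1)
print(old)  # [[0. 1. 0.]] -- wrong: the downstream 0.5 never reaches the input

# fixed behaviour: scatter grad_output onto the argmax positions
ret = np.zeros(x.shape, dtype=np.float32)
np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=1)
print(ret)  # [[0.  0.5 0. ]] -- chain rule respected
```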
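
On the README side, each of the 15 basic ops follows the same `Function` pattern visible in the ops_cpu.py hunk: a `forward`/`backward` pair of static methods plus a `register` call. A minimal sketch of that contract using a unary op, assuming `Function` and `register` are importable from `tinygrad.tensor` as ops_cpu.py's usage suggests (the real Relu in the file may differ in detail):

```python
import numpy as np
from tinygrad.tensor import Function, register

class Relu(Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)       # stash whatever backward will need
    return np.maximum(input, 0)
  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * (input >= 0)  # pass gradient only where the input was active
register('relu', Relu)
```

An accelerator port supplies the same 15 forward/backward pairs against its own buffer type instead of numpy arrays.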