fix max op, less lines
@@ -105,12 +105,12 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### Adding an accelerator
 
-You need to support 14 basic ops:
+You need to support 15 basic ops:
 
 ```
 Add, Sub, Mul, Pow            # binary ops
 Relu, Log, Exp                # unary ops
-Sum                           # reduce op
+Sum, Max                      # reduce ops
 Dot                           # matrix multiplication
 Conv2D, MaxPool2D             # 2D ops
 Pad2D, Reshape, Transpose     # moving things around ops
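For context on what "supporting an op" means here: each backend implements these ops as Function subclasses with a forward and a backward, then hooks them up with register(), exactly as ops_cpu.py does with register('max', Max) at the bottom of this diff. Below is a rough sketch of that pattern for a numpy-style backend. The ctx.save_for_backward / ctx.saved_tensors calls mirror the CPU code later in this commit, but the import path and the exact behaviour of register() are assumptions, not taken verbatim from this change.

```python
# Sketch only: the shape of a backend op, modeled on the Max / register('max', Max)
# pattern visible in this diff. Import path and register() semantics are assumptions.
import numpy as np
from tinygrad.tensor import Function, register  # assumed import path

class Relu(Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)        # keep the input so backward knows where x was positive
    return np.maximum(input, 0)

  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * (input >= 0)   # gradient passes only where the input survived the relu

register('relu', Relu)                  # assumed: makes the op callable as a Tensor method
```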
@@ -83,9 +83,11 @@ class TestOps(unittest.TestCase):
   @cpu_only
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
+    helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
   @cpu_only
   def test_max_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
+    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
   def test_sum_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
   def test_mean_axis(self):
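The new .mul(0.5) variants are what actually exercise the bug this commit fixes: the old Max.backward (see the last hunk) ignored grad_output, so a max followed by any other op still received a constant gradient. helper_test_op compares the tinygrad op against a torch reference, which is why the torch side is written as x.max(axis=1)[0] (torch.max with a dim returns values and indices). A standalone torch snippet showing the behaviour the scaled test pins down, using values chosen so the maximum is unique:

```python
# Torch reference behaviour for the .mul(0.5) case: the upstream 0.5 must
# show up in the input gradient, only at the argmax position.
import torch

x = torch.tensor([[1., 5., 3.],
                  [2., 0., 4.]], requires_grad=True)
x.max().mul(0.5).backward()
print(x.grad)
# [[0.0, 0.5, 0.0],
#  [0.0, 0.0, 0.0]]   (0.5 at the argmax, zeros elsewhere)
```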
@@ -75,10 +75,7 @@ class Max(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
     am = input.argmax(axis=axis)
-    if axis is not None:
-      am = np.expand_dims(am, axis=axis)
-    else:
-      am = np.array([am])
+    am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])
     ctx.save_for_backward(input.shape, am, axis)
     return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)
 
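The forward change only collapses the if/else into a conditional expression; the mechanics are unchanged. argmax records which element won along the axis, expand_dims keeps that index array broadcastable, and take_along_axis pulls out the max values while the saved indices let backward route the gradient. A pure-numpy walk-through of those three steps for axis=1:

```python
# Pure numpy trace of Max.forward with axis=1 (no tinygrad required).
import numpy as np

x = np.array([[1., 5., 3.],
              [2., 0., 4.]], dtype=np.float32)
axis = 1

am = x.argmax(axis=axis)                                                     # [1, 2]
am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])  # [[1], [2]]
out = np.take_along_axis(x, am, axis=axis).squeeze(axis=axis)               # [5., 4.]

print(am)   # winning indices, saved for backward
print(out)  # the per-row maxima returned to the caller
```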
@@ -86,7 +83,7 @@ class Max(Function):
   def backward(ctx, grad_output):
     shape, am, axis = ctx.saved_tensors
     ret = np.zeros(shape, dtype=np.float32)
-    np.put_along_axis(ret, am, 1/np.prod(am.shape), axis=axis)
+    np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
     return ret
 register('max', Max)
 
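This line is the actual fix. The old code scattered the constant 1/np.prod(am.shape) into the argmax slots and ignored grad_output entirely, which is exactly what the new .mul(0.5) tests catch; the new code scatters the real upstream gradient into the same positions and leaves everything else at zero. Continuing the numpy trace from the forward hunk:

```python
# Numpy trace of the fixed Max.backward, continuing the forward example above.
import numpy as np

x = np.array([[1., 5., 3.],
              [2., 0., 4.]], dtype=np.float32)
axis = 1
am = np.expand_dims(x.argmax(axis=axis), axis=axis)      # [[1], [2]], as saved by forward

grad_output = np.array([0.5, 0.5], dtype=np.float32)     # whatever gradient arrives from later ops
ret = np.zeros(x.shape, dtype=np.float32)
np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
print(ret)
# [[0.  0.5 0. ]
#  [0.  0.  0.5]]
```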