diff --git a/test/test_ops.py b/test/test_ops.py index 6002c16cae..ba2fa20bae 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -580,6 +580,21 @@ class TestOps(unittest.TestCase): self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError) + def test_bitwise_not(self): + data = [[1,-8,1],[32,1,6]] + tor = torch.tensor(data, dtype=torch.int) + ten = Tensor(data, dtype=dtypes.int32) + helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True) + helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True) + + data = [[True, False]] + tor = torch.tensor(data, dtype=torch.bool) + ten = Tensor(data, dtype=dtypes.bool) + helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True) + helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True) + + self.helper_test_exception([(4)], torch.bitwise_not, Tensor.bitwise_not, expected=RuntimeError) + def test_lshift(self): data = [[0,1,2],[1<<8,1<<16,1<<31-1]] tor = torch.tensor(data, dtype=torch.int) @@ -993,6 +1008,12 @@ class TestOps(unittest.TestCase): helper_test_op([(45,3)], lambda x: x.min()) helper_test_op([(45,3)], lambda x: x.min().mul(0.5)) helper_test_op([()], lambda x: x.min()) + + helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]]) + helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]]) + helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]]) + helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]]) + def test_max(self): helper_test_op([(45,3)], lambda x: x.max()) helper_test_op([(45,3)], lambda x: x.max().mul(0.5)) @@ -1000,6 +1021,11 @@ class TestOps(unittest.TestCase): helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1)) helper_test_op([()], lambda x: x.max()) + helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]]) + helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]]) + helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]]) + helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]]) + @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)") def test_any(self): helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 96afc8a423..a17a5f42ac 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1575,8 +1575,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.min(axis=1, keepdim=True).numpy()) ``` """ - if dtypes.is_unsigned(self.dtype): - return dtypes.max(self.dtype) - (dtypes.max(self.dtype) - self).max(axis=axis, keepdim=keepdim) + if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim)) return -((-self).max(axis=axis, keepdim=keepdim)) def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): @@ -2982,6 +2981,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return F.BitwiseOr.apply(*self._broadcasted(x, reverse)) + def bitwise_not(self) -> Tensor: + """ + Compute the bit-wise NOT of `self`. + Equivalent to `~self`. + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([True, False]).bitwise_not().numpy()) + ``` + """ + if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") + return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1) + def lshift(self, x:int): """ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. @@ -3098,6 +3111,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # ***** op wrappers ***** + def __invert__(self) -> Tensor: return self.bitwise_not() + def __lshift__(self, x) -> Tensor: return self.lshift(x) def __rshift__(self, x) -> Tensor: return self.rshift(x)