Tensor.bitwise_not (#7688)

implemented with xor in tensor for now to not add another op. also used it in Tensor.min to fix dtype int on -2**31
This commit is contained in:
chenyu
2024-11-13 16:31:52 -05:00
committed by GitHub
parent 0423db8d00
commit 333f5f9f8b
2 changed files with 43 additions and 2 deletions

View File

@@ -580,6 +580,21 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError)
def test_bitwise_not(self):
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True)
helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True)
data = [[True, False]]
tor = torch.tensor(data, dtype=torch.bool)
ten = Tensor(data, dtype=dtypes.bool)
helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True)
helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True)
self.helper_test_exception([(4)], torch.bitwise_not, Tensor.bitwise_not, expected=RuntimeError)
def test_lshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)
@@ -993,6 +1008,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,3)], lambda x: x.min())
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max())
helper_test_op([(45,3)], lambda x: x.max().mul(0.5))
@@ -1000,6 +1021,11 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_any(self):
helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)

View File

@@ -1575,8 +1575,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.min(axis=1, keepdim=True).numpy())
```
"""
if dtypes.is_unsigned(self.dtype):
return dtypes.max(self.dtype) - (dtypes.max(self.dtype) - self).max(axis=axis, keepdim=keepdim)
if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
return -((-self).max(axis=axis, keepdim=keepdim))
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
@@ -2982,6 +2981,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
def bitwise_not(self) -> Tensor:
"""
Compute the bit-wise NOT of `self`.
Equivalent to `~self`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([True, False]).bitwise_not().numpy())
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
def lshift(self, x:int):
"""
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
@@ -3098,6 +3111,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# ***** op wrappers *****
def __invert__(self) -> Tensor: return self.bitwise_not()
def __lshift__(self, x) -> Tensor: return self.lshift(x)
def __rshift__(self, x) -> Tensor: return self.rshift(x)