mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.bitwise_not (#7688)
Implemented with XOR in Tensor for now, to avoid adding another op; also used in Tensor.min to fix the int-dtype result on -2**31.
This commit is contained in:
@@ -580,6 +580,21 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError)
|
||||
|
||||
def test_bitwise_not(self):
  """bitwise_not and the ~ operator match torch on int32 and bool tensors; float input raises."""
  cases = [([[1,-8,1],[32,1,6]], torch.int, dtypes.int32),
           ([[True, False]], torch.bool, dtypes.bool)]
  for data, tor_dtype, ten_dtype in cases:
    tor = torch.tensor(data, dtype=tor_dtype)
    ten = Tensor(data, dtype=ten_dtype)
    helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True)
    helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True)
  # float tensors are rejected by both frameworks
  self.helper_test_exception([(4)], torch.bitwise_not, Tensor.bitwise_not, expected=RuntimeError)
|
||||
|
||||
def test_lshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
@@ -993,6 +1008,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,3)], lambda x: x.min())
|
||||
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
|
||||
helper_test_op([()], lambda x: x.min())
|
||||
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max())
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5))
|
||||
@@ -1000,6 +1021,11 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
|
||||
helper_test_op([()], lambda x: x.max())
|
||||
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
|
||||
def test_any(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
|
||||
|
||||
@@ -1575,8 +1575,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
print(t.min(axis=1, keepdim=True).numpy())
|
||||
```
|
||||
"""
|
||||
if dtypes.is_unsigned(self.dtype):
|
||||
return dtypes.max(self.dtype) - (dtypes.max(self.dtype) - self).max(axis=axis, keepdim=keepdim)
|
||||
if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
|
||||
return -((-self).max(axis=axis, keepdim=keepdim))
|
||||
|
||||
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
||||
@@ -2982,6 +2981,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
||||
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def bitwise_not(self) -> Tensor:
  """
  Compute the bit-wise NOT of `self`.
  Equivalent to `~self`.

  Only bool and integer dtypes are supported; any other dtype raises `RuntimeError`.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, False]).bitwise_not().numpy())
  ```
  """
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
  # bool has no meaningful bit pattern to flip — NOT is just logical negation
  if self.dtype == dtypes.bool: return self.logical_not()
  # flip every bit by xor-ing against an all-ones mask as wide as the dtype
  all_ones = (1 << (8 * self.dtype.itemsize)) - 1
  return self ^ all_ones
|
||||
|
||||
def lshift(self, x:int):
|
||||
"""
|
||||
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
||||
@@ -3098,6 +3111,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
|
||||
# ***** op wrappers *****
|
||||
|
||||
def __invert__(self) -> Tensor:
  """Implements the unary `~` operator by delegating to `bitwise_not`."""
  return self.bitwise_not()
|
||||
|
||||
def __lshift__(self, x) -> Tensor:
  """Implements `self << x` by delegating to `lshift`."""
  return self.lshift(x)
|
||||
def __rshift__(self, x) -> Tensor:
  """Implements `self >> x` by delegating to `rshift`."""
  return self.rshift(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user