fix Tensor.bitwise_and and Tensor.bitwise_or to support bool (#7684)

This commit is contained in:
chenyu
2024-11-13 13:10:39 -05:00
committed by GitHub
parent 3d82f8e340
commit 3c6fe4b79a
2 changed files with 28 additions and 8 deletions

View File

@@ -541,26 +541,45 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x.rsqrt())
def test_xor(self):
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
self.helper_test_exception([(4), (4)], torch.bitwise_xor, Tensor.xor, expected=RuntimeError)
def test_and(self):
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
helper_test_op([], lambda: tor&tor, lambda: ten&ten, forward_only=True)
helper_test_op([], lambda: tor&0x1337, lambda: ten&0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337&tor, lambda: 0x1337&ten, forward_only=True)
data = [[True, True, False, False], [True, False, True, False]]
tor0, tor1 = torch.tensor(data[0], dtype=torch.bool), torch.tensor(data[1], dtype=torch.bool)
ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
helper_test_op([], lambda: tor0&tor1, lambda: ten0&ten1, forward_only=True)
self.helper_test_exception([(4), (4)], torch.bitwise_and, Tensor.bitwise_and, expected=RuntimeError)
def test_or(self):
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
helper_test_op([], lambda: tor|tor, lambda: ten|ten, forward_only=True)
helper_test_op([], lambda: tor|0x1337, lambda: ten|0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337|tor, lambda: 0x1337|ten, forward_only=True)
data = [[True, True, False, False], [True, False, True, False]]
tor0, tor1 = torch.tensor(data[0], dtype=torch.bool), torch.tensor(data[1], dtype=torch.bool)
ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
helper_test_op([], lambda: tor0|tor1, lambda: ten0|ten1, forward_only=True)
self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError)
def test_lshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)