mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
fix Tensor.bitwise_and and Tensor.bitwise_or to support bool (#7684)
This commit is contained in:
@@ -541,26 +541,45 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([()], lambda x: x.rsqrt())
|
||||
|
||||
def test_xor(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
data = [[1,-8,1],[32,1,6]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
self.helper_test_exception([(4), (4)], torch.bitwise_xor, Tensor.xor, expected=RuntimeError)
|
||||
|
||||
def test_and(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
data = [[1,-8,1],[32,1,6]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor&tor, lambda: ten&ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor&0x1337, lambda: ten&0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337&tor, lambda: 0x1337&ten, forward_only=True)
|
||||
|
||||
data = [[True, True, False, False], [True, False, True, False]]
|
||||
tor0, tor1 = torch.tensor(data[0], dtype=torch.bool), torch.tensor(data[1], dtype=torch.bool)
|
||||
ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
|
||||
helper_test_op([], lambda: tor0&tor1, lambda: ten0&ten1, forward_only=True)
|
||||
|
||||
self.helper_test_exception([(4), (4)], torch.bitwise_and, Tensor.bitwise_and, expected=RuntimeError)
|
||||
|
||||
def test_or(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
data = [[1,-8,1],[32,1,6]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor|tor, lambda: ten|ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor|0x1337, lambda: ten|0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337|tor, lambda: 0x1337|ten, forward_only=True)
|
||||
|
||||
data = [[True, True, False, False], [True, False, True, False]]
|
||||
tor0, tor1 = torch.tensor(data[0], dtype=torch.bool), torch.tensor(data[1], dtype=torch.bool)
|
||||
ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
|
||||
helper_test_op([], lambda: tor0|tor1, lambda: ten0|ten1, forward_only=True)
|
||||
|
||||
self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError)
|
||||
|
||||
def test_lshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
|
||||
Reference in New Issue
Block a user