lshift and rshift (#4591)

This commit is contained in:
chenyu
2024-05-14 19:16:31 -04:00
committed by GitHub
parent 45e7400e3c
commit 2b0ee74bb6
3 changed files with 38 additions and 4 deletions

View File

@@ -437,6 +437,28 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
def test_lshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.uint32)
# cast to int32 because torch does not support uint32
helper_test_op([], lambda: tor << 0, lambda: (ten << 0).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor << 2, lambda: (ten << 2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor << 31, lambda: (ten << 31).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.__lshift__(2), lambda: ten.__lshift__(2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.bitwise_left_shift(2), lambda: ten.lshift(2).cast(dtypes.int32), forward_only=True)
def test_rshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.uint32)
# cast to int32 because torch does not support uint32
helper_test_op([], lambda: tor >> 0, lambda: (ten >> 0).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor >> 2, lambda: (ten >> 2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor >> 31, lambda: (ten >> 31).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())