mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-03 19:25:06 -05:00
lshift and rshift (#4591)
This commit is contained in:
@@ -437,6 +437,28 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
def test_lshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.uint32)
|
||||
# cast to int32 because torch does not support uint32
|
||||
helper_test_op([], lambda: tor << 0, lambda: (ten << 0).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor << 2, lambda: (ten << 2).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor << 31, lambda: (ten << 31).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor.__lshift__(2), lambda: ten.__lshift__(2).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor.bitwise_left_shift(2), lambda: ten.lshift(2).cast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_rshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.uint32)
|
||||
# cast to int32 because torch does not support uint32
|
||||
helper_test_op([], lambda: tor >> 0, lambda: (ten >> 0).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor >> 2, lambda: (ten >> 2).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor >> 31, lambda: (ten >> 31).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
|
||||
helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_sin(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sin())
|
||||
helper_test_op([()], lambda x: x.sin())
|
||||
|
||||
Reference in New Issue
Block a user