Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-07 03:00:26 -04:00).
[Feature] Added BinaryOps.AND/BinaryOps.OR (#5223)
* [Feature] Added BinaryOps.AND/BinaryOps.OR * Add: __rand__, __ror__
This commit is contained in:
@@ -26,7 +26,8 @@ if Device.DEFAULT == "LLVM":
|
||||
binary_operations.remove(operator.lt)
|
||||
binary_operations.remove(operator.eq)
|
||||
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)]
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
|
||||
(Tensor.bitwise_or, np.bitwise_or)]
|
||||
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
|
||||
(Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
|
||||
|
||||
|
||||
@@ -467,6 +467,20 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
def test_and(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor&tor, lambda: ten&ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor&0x1337, lambda: ten&0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337&tor, lambda: 0x1337&ten, forward_only=True)
|
||||
|
||||
def test_or(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor|tor, lambda: ten|ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor|0x1337, lambda: ten|0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337|tor, lambda: 0x1337|ten, forward_only=True)
|
||||
|
||||
def test_lshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
|
||||
@@ -124,6 +124,8 @@ class TestNonFloatUOps(TestUOps):
|
||||
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_mod_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.MOD,
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
@@ -157,6 +159,8 @@ class TestBoolUOps(TestUOps):
|
||||
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
|
||||
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
|
||||
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
|
||||
def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
|
||||
def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
|
||||
def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
|
||||
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
|
||||
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
|
||||
|
||||
Reference in New Issue
Block a user