diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 6bfcd410f..4b183e882 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -302,6 +302,66 @@ class BaseTracer(ABC): self, other, lambda x, y: x != y, "ne" ) + def __pow__(self, other: Union["BaseTracer", Any]): + # x ** cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x ** y, "pow" + ) + + def __rpow__(self, other: Union["BaseTracer", Any]): + # cst ** x + return self._helper_for_binary_functions_with_one_cst_input( + other, self, lambda x, y: x ** y, "pow" + ) + + def __mod__(self, other: Union["BaseTracer", Any]): + # x % cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x % y, "mod" + ) + + def __rmod__(self, other: Union["BaseTracer", Any]): + # cst % x + return self._helper_for_binary_functions_with_one_cst_input( + other, self, lambda x, y: x % y, "mod" + ) + + def __and__(self, other: Union["BaseTracer", Any]): + # x & cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x & y, "and" + ) + + def __rand__(self, other: Union["BaseTracer", Any]): + # cst & x + return self._helper_for_binary_functions_with_one_cst_input( + other, self, lambda x, y: x & y, "and" + ) + + def __or__(self, other: Union["BaseTracer", Any]): + # x | cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x | y, "or" + ) + + def __ror__(self, other: Union["BaseTracer", Any]): + # cst | x + return self._helper_for_binary_functions_with_one_cst_input( + other, self, lambda x, y: x | y, "or" + ) + + def __xor__(self, other: Union["BaseTracer", Any]): + # x ^ cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x ^ y, "xor" + ) + + def __rxor__(self, other: Union["BaseTracer", Any]): + # cst ^ x + return self._helper_for_binary_functions_with_one_cst_input( + other, self, lambda x, y: x ^ y, "xor" + ) + def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index df2daaa12..32403c6d5 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -859,6 +859,106 @@ def test_tracing_numpy_calls( ], ), # pylint: enable=misplaced-comparison-constant + ( + lambda x: x & 11, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i & 11 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: 13 & x, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i & 13 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: x | 6, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i | 6 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: 30 | x, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i | 30 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: x ^ 91, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: 115 ^ x, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: x % 11, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i % 11 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: 150 % (x + 1), + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: x ** 2, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.array([i ** 2 for i in range(15)]).reshape(3, 5), + ), + ], + ), + ( + lambda x: 2 ** x, + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5) % 7, + numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5), + ), + ], + ), ], ) def test_tracing_ndarray_calls(