diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index cfd58a43f..8a37b3530 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -244,6 +244,36 @@ class BaseTracer(ABC): # cst >> x return self._rshift(other, self) + def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x > cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x > y, "gt" + ) + + def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x >= cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x >= y, "ge" + ) + + def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x < cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x < y, "lt" + ) + + def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x <= cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x <= y, "le" + ) + + def __ne__(self, other: Union["BaseTracer", Any]): + # x != cst + return self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x != y, "ne" + ) + def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 9001555ed..a578c7906 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -736,6 +736,131 @@ def test_tracing_numpy_calls( ) ], ), + pytest.param( + lambda x: x > 4, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) > 4, + ) + ], + ), + pytest.param( + lambda x: x < 5, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) < 5, + ) + ], + ), + pytest.param( + lambda x: x <= 7, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) <= 7, + ) + ], + ), + pytest.param( + lambda x: x >= 9, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) >= 9, + ) + ], + ), + # FIXME: coming soon, #936 + # pytest.param( + # lambda x: x == 11, + # [ + # ( + # EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + # numpy.arange(15).reshape(3, 5), + # numpy.arange(15).reshape(3, 5) == 11, + # ) + # ], + # ), + pytest.param( + lambda x: x != 12, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) != 12, + ) + ], + ), + # Remove misplaced-comparison-constant because precisely, we want to be sure it works fine + # pylint: disable=misplaced-comparison-constant + pytest.param( + lambda x: 4 > x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 4 > numpy.arange(15).reshape(3, 5), + ) + ], + ), + pytest.param( + lambda x: 5 < x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 5 < numpy.arange(15).reshape(3, 5), + ) + ], + ), + pytest.param( + lambda x: 7 <= x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 7 <= numpy.arange(15).reshape(3, 5), + ) + ], + ), + pytest.param( + lambda x: 9 >= x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 9 >= numpy.arange(15).reshape(3, 5), + ) + ], + ), + # FIXME: coming soon, #936 + # pytest.param( + # lambda x: 11 == x, + # [ + # ( + # EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + # numpy.arange(15).reshape(3, 5), + # 11 == numpy.arange(15).reshape(3, 5), + # ) + # ], + # ), + pytest.param( + lambda x: 12 != x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 12 != numpy.arange(15).reshape(3, 5), + ) + ], + ), + # pylint: enable=misplaced-comparison-constant ], ) def test_tracing_ndarray_calls(