feat: add support for comparators when one input is a constant

closes #932
refs #751
This commit is contained in:
Benoit Chevallier-Mames
2021-11-17 12:58:37 +01:00
committed by Benoit Chevallier
parent c733daa78c
commit a712b0573c
2 changed files with 155 additions and 0 deletions

View File

@@ -244,6 +244,36 @@ class BaseTracer(ABC):
# cst >> x
return self._rshift(other, self)
def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x > cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x > y, "gt"
)
def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x >= cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x >= y, "ge"
)
def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x < cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x < y, "lt"
)
def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x <= cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x <= y, "le"
)
def __ne__(self, other: Union["BaseTracer", Any]):
# x != cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x != y, "ne"
)
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented

View File

@@ -736,6 +736,131 @@ def test_tracing_numpy_calls(
)
],
),
pytest.param(
lambda x: x > 4,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) > 4,
)
],
),
pytest.param(
lambda x: x < 5,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) < 5,
)
],
),
pytest.param(
lambda x: x <= 7,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) <= 7,
)
],
),
pytest.param(
lambda x: x >= 9,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) >= 9,
)
],
),
# FIXME: coming soon, #936
# pytest.param(
# lambda x: x == 11,
# [
# (
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
# numpy.arange(15).reshape(3, 5),
# numpy.arange(15).reshape(3, 5) == 11,
# )
# ],
# ),
pytest.param(
lambda x: x != 12,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) != 12,
)
],
),
# Remove misplaced-comparison-constant because precisely, we want to be sure it works fine
# pylint: disable=misplaced-comparison-constant
pytest.param(
lambda x: 4 > x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
4 > numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 5 < x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
5 < numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 7 <= x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
7 <= numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 9 >= x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
9 >= numpy.arange(15).reshape(3, 5),
)
],
),
# FIXME: coming soon, #936
# pytest.param(
# lambda x: 11 == x,
# [
# (
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
# numpy.arange(15).reshape(3, 5),
# 11 == numpy.arange(15).reshape(3, 5),
# )
# ],
# ),
pytest.param(
lambda x: 12 != x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
12 != numpy.arange(15).reshape(3, 5),
)
],
),
# pylint: enable=misplaced-comparison-constant
],
)
def test_tracing_ndarray_calls(