feat(tracing): add support for more arithmetic operators, when one input is a constant

refs #218
refs #751
This commit is contained in:
Benoit Chevallier-Mames
2021-11-18 09:53:53 +01:00
committed by Benoit Chevallier
parent 507ccd05c5
commit 4c6e1661ec
2 changed files with 160 additions and 0 deletions

View File

@@ -302,6 +302,66 @@ class BaseTracer(ABC):
self, other, lambda x, y: x != y, "ne"
)
def __pow__(self, other: Union["BaseTracer", Any]):
# x ** cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x ** y, "pow"
)
def __rpow__(self, other: Union["BaseTracer", Any]):
# cst ** x
return self._helper_for_binary_functions_with_one_cst_input(
other, self, lambda x, y: x ** y, "pow"
)
def __mod__(self, other: Union["BaseTracer", Any]):
# x % cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x % y, "mod"
)
def __rmod__(self, other: Union["BaseTracer", Any]):
# cst % x
return self._helper_for_binary_functions_with_one_cst_input(
other, self, lambda x, y: x % y, "mod"
)
def __and__(self, other: Union["BaseTracer", Any]):
# x & cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x & y, "and"
)
def __rand__(self, other: Union["BaseTracer", Any]):
# cst & x
return self._helper_for_binary_functions_with_one_cst_input(
other, self, lambda x, y: x & y, "and"
)
def __or__(self, other: Union["BaseTracer", Any]):
# x | cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x | y, "or"
)
def __ror__(self, other: Union["BaseTracer", Any]):
# cst | x
return self._helper_for_binary_functions_with_one_cst_input(
other, self, lambda x, y: x | y, "or"
)
def __xor__(self, other: Union["BaseTracer", Any]):
# x ^ cst
return self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x ^ y, "xor"
)
def __rxor__(self, other: Union["BaseTracer", Any]):
# cst ^ x
return self._helper_for_binary_functions_with_one_cst_input(
other, self, lambda x, y: x ^ y, "xor"
)
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented

View File

@@ -859,6 +859,106 @@ def test_tracing_numpy_calls(
],
),
# pylint: enable=misplaced-comparison-constant
(
lambda x: x & 11,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i & 11 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 13 & x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i & 13 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x | 6,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i | 6 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 30 | x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i | 30 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x ^ 91,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 115 ^ x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x % 11,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i % 11 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 150 % (x + 1),
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x ** 2,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ** 2 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 2 ** x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 7,
numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5),
),
],
),
],
)
def test_tracing_ndarray_calls(