mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(tracing): add support for more arithmetic operators, when one input is a constant
refs #218 refs #751
This commit is contained in:
committed by
Benoit Chevallier
parent
507ccd05c5
commit
4c6e1661ec
@@ -302,6 +302,66 @@ class BaseTracer(ABC):
|
||||
self, other, lambda x, y: x != y, "ne"
|
||||
)
|
||||
|
||||
def __pow__(self, other: Union["BaseTracer", Any]):
|
||||
# x ** cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x ** y, "pow"
|
||||
)
|
||||
|
||||
def __rpow__(self, other: Union["BaseTracer", Any]):
|
||||
# cst ** x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x ** y, "pow"
|
||||
)
|
||||
|
||||
def __mod__(self, other: Union["BaseTracer", Any]):
|
||||
# x % cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x % y, "mod"
|
||||
)
|
||||
|
||||
def __rmod__(self, other: Union["BaseTracer", Any]):
|
||||
# cst % x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x % y, "mod"
|
||||
)
|
||||
|
||||
def __and__(self, other: Union["BaseTracer", Any]):
|
||||
# x & cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x & y, "and"
|
||||
)
|
||||
|
||||
def __rand__(self, other: Union["BaseTracer", Any]):
|
||||
# cst & x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x & y, "and"
|
||||
)
|
||||
|
||||
def __or__(self, other: Union["BaseTracer", Any]):
|
||||
# x | cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x | y, "or"
|
||||
)
|
||||
|
||||
def __ror__(self, other: Union["BaseTracer", Any]):
|
||||
# cst | x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x | y, "or"
|
||||
)
|
||||
|
||||
def __xor__(self, other: Union["BaseTracer", Any]):
|
||||
# x ^ cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x ^ y, "xor"
|
||||
)
|
||||
|
||||
def __rxor__(self, other: Union["BaseTracer", Any]):
|
||||
# cst ^ x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x ^ y, "xor"
|
||||
)
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -859,6 +859,106 @@ def test_tracing_numpy_calls(
|
||||
],
|
||||
),
|
||||
# pylint: enable=misplaced-comparison-constant
|
||||
(
|
||||
lambda x: x & 11,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i & 11 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 13 & x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i & 13 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x | 6,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i | 6 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 30 | x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i | 30 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x ^ 91,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 115 ^ x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x % 11,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i % 11 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 150 % (x + 1),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x ** 2,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ** 2 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 2 ** x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 7,
|
||||
numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_ndarray_calls(
|
||||
|
||||
Reference in New Issue
Block a user