mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add support for comparators when one input is a constant
closes #932 refs #751
This commit is contained in:
committed by
Benoit Chevallier
parent
c733daa78c
commit
a712b0573c
@@ -244,6 +244,36 @@ class BaseTracer(ABC):
|
||||
# cst >> x
|
||||
return self._rshift(other, self)
|
||||
|
||||
def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x > cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x > y, "gt"
|
||||
)
|
||||
|
||||
def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x >= cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x >= y, "ge"
|
||||
)
|
||||
|
||||
def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x < cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x < y, "lt"
|
||||
)
|
||||
|
||||
def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x <= cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x <= y, "le"
|
||||
)
|
||||
|
||||
def __ne__(self, other: Union["BaseTracer", Any]):
|
||||
# x != cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x != y, "ne"
|
||||
)
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -736,6 +736,131 @@ def test_tracing_numpy_calls(
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x > 4,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) > 4,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x < 5,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) < 5,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x <= 7,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) <= 7,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x >= 9,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) >= 9,
|
||||
)
|
||||
],
|
||||
),
|
||||
# FIXME: coming soon, #936
|
||||
# pytest.param(
|
||||
# lambda x: x == 11,
|
||||
# [
|
||||
# (
|
||||
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
# numpy.arange(15).reshape(3, 5),
|
||||
# numpy.arange(15).reshape(3, 5) == 11,
|
||||
# )
|
||||
# ],
|
||||
# ),
|
||||
pytest.param(
|
||||
lambda x: x != 12,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) != 12,
|
||||
)
|
||||
],
|
||||
),
|
||||
# Remove misplaced-comparison-constant because precisely, we want to be sure it works fine
|
||||
# pylint: disable=misplaced-comparison-constant
|
||||
pytest.param(
|
||||
lambda x: 4 > x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
4 > numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 5 < x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
5 < numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 7 <= x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
7 <= numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 9 >= x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
9 >= numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
# FIXME: coming soon, #936
|
||||
# pytest.param(
|
||||
# lambda x: 11 == x,
|
||||
# [
|
||||
# (
|
||||
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
# numpy.arange(15).reshape(3, 5),
|
||||
# 11 == numpy.arange(15).reshape(3, 5),
|
||||
# )
|
||||
# ],
|
||||
# ),
|
||||
pytest.param(
|
||||
lambda x: 12 != x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
12 != numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
# pylint: enable=misplaced-comparison-constant
|
||||
],
|
||||
)
|
||||
def test_tracing_ndarray_calls(
|
||||
|
||||
Reference in New Issue
Block a user