mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support all kind of subtraction operations
This commit is contained in:
@@ -137,8 +137,6 @@ class GraphConverter:
|
||||
|
||||
elif name == "subtract":
|
||||
assert_that(len(inputs) == 2)
|
||||
if not (inputs[0].is_clear and inputs[1].is_encrypted):
|
||||
return "only subtraction of encrypted from clear is supported"
|
||||
|
||||
elif name == "sum":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
@@ -840,10 +840,23 @@ class NodeConverter:
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
preds = self.preds
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.SubIntEintOp(resulting_type, *preds).result
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.SubEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.SubEintOp(resulting_type, *preds).result
|
||||
|
||||
elif self.node.inputs[0].is_clear:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.SubIntEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.SubIntEintOp(resulting_type, *preds).result
|
||||
|
||||
else:
|
||||
result = fhe.SubIntEintOp(resulting_type, *preds).result
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.SubEintIntOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.SubEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -11,14 +11,26 @@ import concrete.numpy as cnp
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x - 42,
|
||||
id="x - 42",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 42 - x,
|
||||
id="42 - x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - np.array([1, 2, 3]),
|
||||
id="x - [1, 2, 3]",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.array([1, 2, 3]) - x,
|
||||
id="[1, 2, 3] - x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - np.array([[1, 2, 3], [4, 5, 6]]),
|
||||
id="x - [[1, 2, 3], [4, 5, 6]]",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.array([[1, 2, 3], [4, 5, 6]]) - x,
|
||||
id="[[1, 2, 3], [4, 5, 6]] - x",
|
||||
@@ -54,3 +66,94 @@ def test_constant_sub(function, parameters, helpers):
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
id="x - y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "clear"},
|
||||
"y": {"range": [0, 60], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted"},
|
||||
"y": {"range": [0, 60], "status": "clear"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted"},
|
||||
"y": {"range": [0, 60], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "clear", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "clear"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "clear"},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted"},
|
||||
"y": {"range": [0, 60], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted"},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "clear", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "clear", "shape": (2, 1)},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)},
|
||||
"y": {"range": [0, 60], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_sub(function, parameters, helpers):
|
||||
"""
|
||||
Test sub where both of the operators are dynamic.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -286,23 +286,6 @@ return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - 1,
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2**3, size=(2, 3)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(2, 3)>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = subtract(%0, %1) # ClearTensor<int4, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only subtraction of encrypted from clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.sum(x),
|
||||
{"x": "clear"},
|
||||
|
||||
Reference in New Issue
Block a user