Files
concrete/frontends/concrete-python/tests/execution/test_sum.py

133 lines
3.8 KiB
Python

"""
Tests of execution of the sum operation.
"""
import numpy as np
import pytest
from concrete import fhe
# Every case feeds the same input description to the helpers: a single
# parameter "x" with shape (3, 2), "range" [0, 10] and status "encrypted".
# The dict literal sits inside the comprehension so that each case still
# receives its own freshly built dictionary.
@pytest.mark.parametrize(
    "function,parameters",
    [
        pytest.param(
            operation,
            {
                "x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
            },
        )
        for operation in (
            # default reduction and positional axis
            lambda x: np.sum(x),
            lambda x: np.sum(x, 0),
            lambda x: np.sum(x, 1),
            # keyword axis: None, positive, negative and tuple forms
            lambda x: np.sum(x, axis=None),  # type: ignore
            lambda x: np.sum(x, axis=0),
            lambda x: np.sum(x, axis=1),
            lambda x: np.sum(x, axis=-1),
            lambda x: np.sum(x, axis=-2),
            lambda x: np.sum(x, axis=(0, 1)),
            lambda x: np.sum(x, axis=(-2, -1)),
            # same axis selections with keepdims=True
            lambda x: np.sum(x, keepdims=True),
            lambda x: np.sum(x, axis=0, keepdims=True),
            lambda x: np.sum(x, axis=1, keepdims=True),
            lambda x: np.sum(x, axis=-1, keepdims=True),
            lambda x: np.sum(x, axis=-2, keepdims=True),
            lambda x: np.sum(x, axis=(0, 1), keepdims=True),
            lambda x: np.sum(x, axis=(-2, -1), keepdims=True),
        )
    ],
)
def test_sum(function, parameters, helpers):
    """
    Compile one `np.sum` variant and check its execution.

    Args:
        function: the callable under test (one of the parametrized lambdas).
        parameters: mapping from parameter name to its description
            ("shape", "range" and "status" keys), passed to the helpers.
        helpers: test helper fixture used to derive encryption statuses,
            the configuration, the inputset and the sample, and to run the
            final execution check.
    """
    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration()

    sum_compiler = fhe.Compiler(function, statuses)
    circuit = sum_compiler.compile(helpers.generate_inputset(parameters), config)

    # The sample is generated only after compilation, as in the original flow.
    helpers.check_execution(circuit, function, helpers.generate_sample(parameters))