mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: move configuration and artifacts to compile and trace methods
This commit is contained in:
@@ -19,12 +19,12 @@ def test_artifacts_export(helpers):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts(tmpdir)
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return x + 10
|
||||
|
||||
inputset = range(100)
|
||||
f.compile(inputset)
|
||||
f.compile(inputset, configuration, artifacts)
|
||||
|
||||
artifacts.export()
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ def test_circuit_str(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
assert str(circuit) == (
|
||||
"""
|
||||
@@ -44,12 +44,12 @@ def test_circuit_draw(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
tmpdir = Path(path)
|
||||
@@ -67,12 +67,12 @@ def test_circuit_bad_run(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
# with 1 argument
|
||||
# ---------------
|
||||
@@ -138,12 +138,12 @@ def test_circuit_virtual_explicit_api(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset, virtual=True)
|
||||
circuit = f.compile(inputset, configuration, virtual=True)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.keygen()
|
||||
|
||||
@@ -7,13 +7,11 @@ import pytest
|
||||
from concrete.numpy.compilation import Compiler
|
||||
|
||||
|
||||
def test_compiler_bad_init(helpers):
|
||||
def test_compiler_bad_init():
|
||||
"""
|
||||
Test `__init__` method of `Compiler` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def f(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
@@ -21,7 +19,7 @@ def test_compiler_bad_init(helpers):
|
||||
# -----------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {}, configuration=configuration)
|
||||
Compiler(f, {})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption statuses of parameters 'x', 'y' and 'z' of function 'f' are not provided"
|
||||
@@ -31,7 +29,7 @@ def test_compiler_bad_init(helpers):
|
||||
# ---------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {"z": "clear"}, configuration=configuration)
|
||||
Compiler(f, {"z": "clear"})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption statuses of parameters 'x' and 'y' of function 'f' are not provided"
|
||||
@@ -41,7 +39,7 @@ def test_compiler_bad_init(helpers):
|
||||
# ---------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {"y": "encrypted", "z": "clear"}, configuration=configuration)
|
||||
Compiler(f, {"y": "encrypted", "z": "clear"})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption status of parameter 'x' of function 'f' is not provided"
|
||||
@@ -52,29 +50,19 @@ def test_compiler_bad_init(helpers):
|
||||
|
||||
# this is fine and `p` is just ignored
|
||||
|
||||
Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"})
|
||||
|
||||
|
||||
def test_compiler_bad_call(helpers):
|
||||
def test_compiler_bad_call():
|
||||
"""
|
||||
Test `__call__` method of `Compiler` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def f(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler = Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear"})
|
||||
compiler(1, 2, 3, invalid=4)
|
||||
|
||||
assert str(excinfo.value) == "Calling function 'f' with kwargs is not supported"
|
||||
@@ -94,9 +82,8 @@ def test_compiler_bad_trace(helpers):
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler.trace()
|
||||
compiler.trace(configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported"
|
||||
|
||||
@@ -115,17 +102,18 @@ def test_compiler_bad_compile(helpers):
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler.compile()
|
||||
compiler.compile(configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported"
|
||||
|
||||
configuration.enable_unsafe_features = False
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(lambda x: x, {"x": "encrypted"}, configuration=configuration)
|
||||
compiler.compile(virtual=True)
|
||||
compiler = Compiler(lambda x: x, {"x": "encrypted"})
|
||||
compiler.compile(
|
||||
range(10),
|
||||
configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False),
|
||||
virtual=True,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
@@ -142,7 +130,7 @@ def test_compiler_virtual_compile(helpers):
|
||||
def f(x):
|
||||
return x + 400
|
||||
|
||||
compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration)
|
||||
circuit = compiler.compile(inputset=range(400), virtual=True)
|
||||
compiler = Compiler(f, {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset=range(400), configuration=configuration, virtual=True)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(200) == 600
|
||||
|
||||
@@ -12,14 +12,14 @@ def test_call_compile(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
for i in range(10):
|
||||
function(i)
|
||||
|
||||
circuit = function.compile()
|
||||
circuit = function.compile(configuration=configuration)
|
||||
|
||||
sample = 5
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -33,12 +33,12 @@ def test_compiler_verbose_trace(helpers, capsys):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.trace(inputset, show_graph=True)
|
||||
function.trace(inputset, configuration, artifacts, show_graph=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
@@ -61,12 +61,12 @@ def test_compiler_verbose_compile(helpers, capsys):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, show_graph=True, show_mlir=True)
|
||||
function.compile(inputset, configuration, artifacts, show_graph=True, show_mlir=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
|
||||
@@ -59,10 +59,10 @@ def test_constant_add(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -150,10 +150,10 @@ def test_add(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -167,10 +167,10 @@ def test_concatenate(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -57,12 +57,12 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(x, weight, bias, strides=strides, dilations=dilations)
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
circuit = function.compile(inputset)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -373,7 +373,7 @@ def test_bad_conv_compilation(
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(
|
||||
x,
|
||||
@@ -389,7 +389,7 @@ def test_bad_conv_compilation(
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
function.compile(inputset)
|
||||
function.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
|
||||
@@ -166,10 +166,10 @@ def test_direct_table_lookup(bits, function, helpers):
|
||||
# scalar
|
||||
# ------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = range(2 ** bits)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = int(np.random.randint(0, 2 ** bits))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -177,10 +177,10 @@ def test_direct_table_lookup(bits, function, helpers):
|
||||
# tensor
|
||||
# ------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -207,10 +207,10 @@ def test_direct_multi_table_lookup(helpers):
|
||||
def function(x):
|
||||
return table[x]
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -277,22 +277,22 @@ def test_bad_direct_table_lookup(helpers):
|
||||
# compilation with float value
|
||||
# ----------------------------
|
||||
|
||||
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"})
|
||||
|
||||
inputset = [1.5]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "LookupTable cannot be looked up with EncryptedScalar<float64>"
|
||||
|
||||
# compilation with invalid shape
|
||||
# ------------------------------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"})
|
||||
|
||||
inputset = [10, 5, 6, 2]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"LookupTable of shape (3, 2) cannot be looked up with EncryptedScalar<uint4>"
|
||||
|
||||
@@ -22,23 +22,23 @@ def test_dot(size, helpers):
|
||||
bound = int(np.floor(np.sqrt(127 / size)))
|
||||
cst = np.random.randint(0, bound, size=(size,))
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def left_function(x):
|
||||
return np.dot(x, cst)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def right_function(x):
|
||||
return np.dot(cst, x)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def method(x):
|
||||
return x.dot(cst)
|
||||
|
||||
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
|
||||
|
||||
left_function_circuit = left_function.compile(inputset)
|
||||
right_function_circuit = right_function.compile(inputset)
|
||||
method_circuit = method.compile(inputset)
|
||||
left_function_circuit = left_function.compile(inputset, configuration)
|
||||
right_function_circuit = right_function.compile(inputset, configuration)
|
||||
method_circuit = method.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, bound, size=(size,), dtype=np.uint8)
|
||||
|
||||
|
||||
@@ -125,29 +125,29 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
lhs_cst = list(np.random.randint(minimum, maximum, size=lhs_shape))
|
||||
rhs_cst = list(np.random.randint(minimum, maximum, size=rhs_shape))
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def lhs_operator(x):
|
||||
return x @ rhs_cst
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def rhs_operator(x):
|
||||
return lhs_cst @ x
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def lhs_function(x):
|
||||
return np.matmul(x, rhs_cst)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def rhs_function(x):
|
||||
return np.matmul(lhs_cst, x)
|
||||
|
||||
lhs_inputset = [np.random.randint(minimum, maximum, size=lhs_shape) for i in range(100)]
|
||||
rhs_inputset = [np.random.randint(minimum, maximum, size=rhs_shape) for i in range(100)]
|
||||
|
||||
lhs_operator_circuit = lhs_operator.compile(lhs_inputset)
|
||||
rhs_operator_circuit = rhs_operator.compile(rhs_inputset)
|
||||
lhs_function_circuit = lhs_function.compile(lhs_inputset)
|
||||
rhs_function_circuit = rhs_function.compile(rhs_inputset)
|
||||
lhs_operator_circuit = lhs_operator.compile(lhs_inputset, configuration)
|
||||
rhs_operator_circuit = rhs_operator.compile(rhs_inputset, configuration)
|
||||
lhs_function_circuit = lhs_function.compile(lhs_inputset, configuration)
|
||||
rhs_function_circuit = rhs_function.compile(rhs_inputset, configuration)
|
||||
|
||||
lhs_sample = np.random.randint(minimum, maximum, size=lhs_shape, dtype=np.uint8)
|
||||
rhs_sample = np.random.randint(minimum, maximum, size=rhs_shape, dtype=np.uint8)
|
||||
|
||||
@@ -67,10 +67,10 @@ def test_constant_mul(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -27,18 +27,18 @@ def test_neg(parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
|
||||
@cnp.compiler(parameter_encryption_statuses)
|
||||
def operator(x):
|
||||
return -x
|
||||
|
||||
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
|
||||
@cnp.compiler(parameter_encryption_statuses)
|
||||
def function(x):
|
||||
return np.negative(x)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
|
||||
operator_circuit = operator.compile(inputset)
|
||||
function_circuit = function.compile(inputset)
|
||||
operator_circuit = operator.compile(inputset, configuration)
|
||||
function_circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
|
||||
|
||||
@@ -447,10 +447,10 @@ def test_others(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -464,10 +464,10 @@ def test_others(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -483,13 +483,13 @@ def test_others_bad_fusing(helpers):
|
||||
# two variable inputs
|
||||
# -------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"})
|
||||
def function1(x, y):
|
||||
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [(i, i) for i in range(100)]
|
||||
function1.compile(inputset)
|
||||
function1.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
@@ -528,13 +528,13 @@ return %13
|
||||
# big intermediate constants
|
||||
# --------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function2(x):
|
||||
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = range(100)
|
||||
function2.compile(inputset)
|
||||
function2.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
@@ -559,13 +559,13 @@ return %4
|
||||
# intermediates with different shape
|
||||
# ----------------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function3(x):
|
||||
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)]
|
||||
function3.compile(inputset)
|
||||
function3.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
@@ -116,18 +116,18 @@ def test_reshape(shape, newshape, helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return np.reshape(x, newshape)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def method(x):
|
||||
return x.reshape(newshape)
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
|
||||
|
||||
function_circuit = function.compile(inputset)
|
||||
method_circuit = method.compile(inputset)
|
||||
function_circuit = function.compile(inputset, configuration)
|
||||
method_circuit = method.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
|
||||
@@ -159,12 +159,12 @@ def test_flatten(shape, helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x.flatten()
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
|
||||
circuit = function.compile(inputset)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -154,10 +154,10 @@ def test_static_indexing(shape, function, helpers):
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -173,21 +173,21 @@ def test_bad_static_indexing(helpers):
|
||||
# with float
|
||||
# ----------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Indexing with '1.5' is not supported"
|
||||
|
||||
# with bad slice
|
||||
# --------------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Indexing with '1.5:2.5' is not supported"
|
||||
|
||||
@@ -47,10 +47,10 @@ def test_constant_sub(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -105,10 +105,10 @@ def test_sum(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -39,10 +39,10 @@ def test_transpose(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -389,10 +389,10 @@ def test_graph_converter_bad_convert(
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, encryption_statuses)
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@@ -39,8 +39,8 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration=configuration)
|
||||
graph = compiler.trace(inputset)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
print(graph.format())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user