refactor(frontend-python): use @fhe.compiler instead of importing it and doing @compiler

This commit is contained in:
Umut
2023-07-24 14:57:54 +02:00
parent 885d25424d
commit e49c16873c

View File

@@ -8,7 +8,8 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, Value, compiler
from concrete import fhe
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, Value
def test_circuit_str(helpers):
@@ -18,7 +19,7 @@ def test_circuit_str(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
@@ -38,7 +39,7 @@ def test_circuit_feedback(helpers):
p_error = 0.1
global_p_error = 0.05
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return np.sqrt(((x + y) ** 2) + 10).astype(np.int64)
@@ -65,7 +66,7 @@ def test_circuit_bad_run(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
@@ -153,7 +154,7 @@ def test_circuit_separate_args(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def function(x, y):
return x + y
@@ -185,7 +186,7 @@ def test_client_server_api(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
@@ -248,7 +249,7 @@ def test_client_server_api_crt(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
@fhe.compiler({"x": "encrypted"})
def function(x):
return x**2
@@ -299,7 +300,7 @@ def test_client_server_api_via_mlir(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
@@ -354,7 +355,7 @@ def test_bad_server_save(helpers):
configuration = helpers.configuration().fork(jit=True)
@compiler({"x": "encrypted"})
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
@@ -374,7 +375,7 @@ def test_circuit_run_with_unused_arg(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y): # pylint: disable=unused-argument
return x + 10
@@ -399,7 +400,7 @@ def test_dataflow_circuit(helpers):
configuration = helpers.configuration().fork(dataflow_parallelize=True)
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return (x**2) + (y // 2)
@@ -416,7 +417,7 @@ def test_circuit_sim_disabled(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
@@ -433,7 +434,7 @@ def test_circuit_fhe_exec_disabled(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
@@ -450,7 +451,7 @@ def test_circuit_fhe_exec_no_eval_keys(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
@@ -472,7 +473,7 @@ def test_circuit_eval_graph_scalar(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
@@ -489,7 +490,7 @@ def test_circuit_eval_graph_tensor(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
@@ -512,7 +513,7 @@ def test_circuit_compile_sim_only(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]