diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 5da5311f5..ab786d794 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -63,7 +63,7 @@ class Configuration: use_insecure_key_cache: bool = False, insecure_key_cache_location: Optional[Union[Path, str]] = None, loop_parallelize: bool = True, - dataflow_parallelize: bool = False, + dataflow_parallelize: bool = True, auto_parallelize: bool = False, jit: bool = False, p_error: Optional[float] = None, diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index e9313d1c8..f01ad4dac 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -8,6 +8,7 @@ import tempfile from pathlib import Path from typing import List, Optional, Union +import concrete.compiler from concrete.compiler import ( CompilationFeedback, CompilationOptions, @@ -95,6 +96,9 @@ class Server: options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) + if configuration.auto_parallelize or configuration.dataflow_parallelize: + concrete.compiler.init_dfr() + global_p_error_is_set = configuration.global_p_error is not None p_error_is_set = configuration.p_error is not None diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index a26aeba84..fefdd9c07 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -326,3 +326,20 @@ def test_circuit_run_with_unused_arg(helpers): assert circuit.encrypt_run_decrypt(10, 0) == 20 assert circuit.encrypt_run_decrypt(10, 10) == 20 assert circuit.encrypt_run_decrypt(10, 20) == 20 + + +def test_dataflow_circuit(helpers): + """ + Test execution with dataflow_parallelize=True. + """ + + configuration = helpers.configuration().fork(dataflow_parallelize=True) + + @compiler({"x": "encrypted", "y": "encrypted"}) + def f(x, y): + return (x**2) + (y // 2) + + inputset = [(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)] + circuit = f.compile(inputset, configuration) + + assert circuit.encrypt_run_decrypt(5, 6) == 28