From 91d41a2ff82b2f0e5f91e8578c8d792a16486f73 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 8 Dec 2022 10:54:16 +0100 Subject: [PATCH] fix: add call to init dfr from python Co-authored-by: Antoniu Pop --- .../Bindings/Python/CompilerEngine.h | 5 +++-- .../lib/Bindings/Python/CompilerAPIModule.cpp | 4 +++- compiler/lib/Bindings/Python/CompilerEngine.cpp | 8 ++++---- .../Python/concrete/compiler/__init__.py | 17 ++++++++++++++--- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/compiler/include/concretelang/Bindings/Python/CompilerEngine.h b/compiler/include/concretelang/Bindings/Python/CompilerEngine.h index 9bd602875..6665411fa 100644 --- a/compiler/include/concretelang/Bindings/Python/CompilerEngine.h +++ b/compiler/include/concretelang/Bindings/Python/CompilerEngine.h @@ -147,8 +147,9 @@ MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( /// Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); -/// Terminate parallelization -MLIR_CAPI_EXPORTED void terminateParallelization(); +/// Terminate/Init dataflow parallelization +MLIR_CAPI_EXPORTED void terminateDataflowParallelization(); +MLIR_CAPI_EXPORTED void initDataflowParallelization(); /// Create a lambdaArgument from a tensor of different data types MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 1822b1644..83e1acd06 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -30,7 +30,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m.def("round_trip", [](std::string mlir_input) { return roundTrip(mlir_input.c_str()); }); - m.def("terminate_parallelization", &terminateParallelization); + m.def("terminate_df_parallelization", &terminateDataflowParallelization); + + m.def("init_df_parallelization", &initDataflowParallelization); pybind11::class_(m, "CompilationOptions") .def(pybind11::init( diff --git a/compiler/lib/Bindings/Python/CompilerEngine.cpp b/compiler/lib/Bindings/Python/CompilerEngine.cpp index 80598cc1a..b050b38b0 100644 --- a/compiler/lib/Bindings/Python/CompilerEngine.cpp +++ b/compiler/lib/Bindings/Python/CompilerEngine.cpp @@ -252,10 +252,10 @@ clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) { return jsonParams; } -MLIR_CAPI_EXPORTED void terminateParallelization() { -#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED - _dfr_terminate(); -#endif +MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); } + +MLIR_CAPI_EXPORTED void initDataflowParallelization() { + mlir::concretelang::dfr::_dfr_set_required(true); } MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) { diff --git a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index b9b6d063b..48d9cc72a 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -6,7 +6,8 @@ import atexit # pylint: disable=no-name-in-module,import-error from mlir._mlir_libs._concretelang._compiler import ( - terminate_parallelization as _terminate_parallelization, + terminate_df_parallelization as _terminate_df_parallelization, + init_df_parallelization as _init_df_parallelization, ) from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip @@ -30,8 +31,18 @@ from .library_support import LibrarySupport from .evaluation_keys import EvaluationKeys -# Terminate parallelization in the compiler (if init) during cleanup -atexit.register(_terminate_parallelization) +def init_dfr(): + """Initialize dataflow parallelization. + + It is not always required to initialize the dataflow runtime as it can be implicitely done + during compilation. However, it is required in case no compilation has previously been done + and the runtime is needed""" + _init_df_parallelization() + + +# Cleanly terminate the dataflow runtime if it has been initialized +# (does nothing otherwise) +atexit.register(_terminate_df_parallelization) def round_trip(mlir_str: str) -> str: