fix: add call to init dfr from python

Co-authored-by: Antoniu Pop <antoniu.pop@zama.ai>
This commit is contained in:
youben11
2022-12-08 10:54:16 +01:00
committed by Ayoub Benaissa
parent add68ccf84
commit 91d41a2ff8
4 changed files with 24 additions and 10 deletions

View File

@@ -147,8 +147,9 @@ MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
/// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
/// Terminate parallelization
MLIR_CAPI_EXPORTED void terminateParallelization();
/// Terminate/Init dataflow parallelization
MLIR_CAPI_EXPORTED void terminateDataflowParallelization();
MLIR_CAPI_EXPORTED void initDataflowParallelization();
/// Create a lambdaArgument from a tensor of different data types
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(

View File

@@ -30,7 +30,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("round_trip",
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
m.def("terminate_parallelization", &terminateParallelization);
m.def("terminate_df_parallelization", &terminateDataflowParallelization);
m.def("init_df_parallelization", &initDataflowParallelization);
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
.def(pybind11::init(

View File

@@ -252,10 +252,10 @@ clientParametersSerialize(mlir::concretelang::ClientParameters &params) {
return jsonParams;
}
MLIR_CAPI_EXPORTED void terminateParallelization() {
#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
_dfr_terminate();
#endif
MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); }
MLIR_CAPI_EXPORTED void initDataflowParallelization() {
mlir::concretelang::dfr::_dfr_set_required(true);
}
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) {

View File

@@ -6,7 +6,8 @@ import atexit
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
terminate_parallelization as _terminate_parallelization,
terminate_df_parallelization as _terminate_df_parallelization,
init_df_parallelization as _init_df_parallelization,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
@@ -30,8 +31,18 @@ from .library_support import LibrarySupport
from .evaluation_keys import EvaluationKeys
# Terminate parallelization in the compiler (if init) during cleanup
atexit.register(_terminate_parallelization)
def init_dfr():
"""Initialize dataflow parallelization.
It is not always required to initialize the dataflow runtime as it can be implicitely done
during compilation. However, it is required in case no compilation has previously been done
and the runtime is needed"""
_init_df_parallelization()
# Cleanly terminate the dataflow runtime if it has been initialized
# (does nothing otherwise)
atexit.register(_terminate_df_parallelization)
def round_trip(mlir_str: str) -> str: