mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: add call to init dfr from python
Co-authored-by: Antoniu Pop <antoniu.pop@zama.ai>
This commit is contained in:
@@ -147,8 +147,9 @@ MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
/// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
/// Terminate parallelization
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization();
|
||||
/// Terminate/Init dataflow parallelization
|
||||
MLIR_CAPI_EXPORTED void terminateDataflowParallelization();
|
||||
MLIR_CAPI_EXPORTED void initDataflowParallelization();
|
||||
|
||||
/// Create a lambdaArgument from a tensor of different data types
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
|
||||
@@ -30,7 +30,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
m.def("round_trip",
|
||||
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
|
||||
|
||||
m.def("terminate_parallelization", &terminateParallelization);
|
||||
m.def("terminate_df_parallelization", &terminateDataflowParallelization);
|
||||
|
||||
m.def("init_df_parallelization", &initDataflowParallelization);
|
||||
|
||||
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
|
||||
.def(pybind11::init(
|
||||
|
||||
@@ -252,10 +252,10 @@ clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) {
|
||||
return jsonParams;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization() {
|
||||
#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
|
||||
_dfr_terminate();
|
||||
#endif
|
||||
MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); }
|
||||
|
||||
MLIR_CAPI_EXPORTED void initDataflowParallelization() {
|
||||
mlir::concretelang::dfr::_dfr_set_required(true);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) {
|
||||
|
||||
@@ -6,7 +6,8 @@ import atexit
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
terminate_df_parallelization as _terminate_df_parallelization,
|
||||
init_df_parallelization as _init_df_parallelization,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
|
||||
@@ -30,8 +31,18 @@ from .library_support import LibrarySupport
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
# Terminate parallelization in the compiler (if init) during cleanup
|
||||
atexit.register(_terminate_parallelization)
|
||||
def init_dfr():
|
||||
"""Initialize dataflow parallelization.
|
||||
|
||||
It is not always required to initialize the dataflow runtime as it can be implicitely done
|
||||
during compilation. However, it is required in case no compilation has previously been done
|
||||
and the runtime is needed"""
|
||||
_init_df_parallelization()
|
||||
|
||||
|
||||
# Cleanly terminate the dataflow runtime if it has been initialized
|
||||
# (does nothing otherwise)
|
||||
atexit.register(_terminate_df_parallelization)
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user