From 44ebd426f947f5f3702716a027043e149ecb8458 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 11 Mar 2022 15:56:46 +0100 Subject: [PATCH] feat: setup init/termination of parallel execution in python --- .../concretelang-c/Support/CompilerEngine.h | 6 ++++++ .../lib/Bindings/Python/CompilerAPIModule.cpp | 3 +++ .../lib/Bindings/Python/concrete/compiler.py | 11 +++++++++++ compiler/lib/CAPI/Support/CMakeLists.txt | 4 ++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 19 +++++++++++++++++++ 5 files changed, 43 insertions(+) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index ac0497a36..8f97dcd6f 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -53,6 +53,12 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l, executionArguments args); +// Initialize and terminate parallelization. Init can be called only once (later +// calls might be ignored by the runtime). You shouldn't reinit after +// termination. +MLIR_CAPI_EXPORTED void initParallelization(); +MLIR_CAPI_EXPORTED void terminateParallelization(); + // Create a lambdaArgument from a tensor of different data types MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( std::vector data, std::vector dimensions); diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 512c2f1ff..39afa9bd5 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -39,6 +39,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return library(library_path, mlir_modules); }); + m.def("init_parallelization", &initParallelization); + m.def("terminate_parallelization", &terminateParallelization); + pybind11::class_(m, "JitCompilerEngine") .def(pybind11::init()) .def_static("build_lambda", diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py index 9f139da57..cad8f6830 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ b/compiler/lib/Bindings/Python/concrete/compiler.py @@ -4,10 +4,13 @@ """Compiler submodule""" from collections.abc import Iterable import os +import atexit from typing import List, Union from mlir._mlir_libs._concretelang._compiler import ( JitCompilerEngine as _JitCompilerEngine, + init_parallelization as _init_parallelization, + terminate_parallelization as _terminate_parallelization, ) from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip @@ -20,6 +23,10 @@ ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS +# Terminate parallelization in the compiler (if init) during cleanup +atexit.register(_terminate_parallelization) + + def _lookup_runtime_lib() -> str: """Try to find the absolute path to the runtime library. @@ -177,6 +184,10 @@ class CompilerEngine: unsecure_key_set_cache_path = unsecure_key_set_cache_path or "" if not isinstance(unsecure_key_set_cache_path, str): raise TypeError("unsecure_key_set_cache_path must be a str") + + if any([auto_parallelize, loop_parallelize, df_parallelize]): + # Multiple calls should be guarded in the compiler and only result in a single init + _init_parallelization() self._lambda = self._engine.build_lambda( mlir_str, func_name, diff --git a/compiler/lib/CAPI/Support/CMakeLists.txt b/compiler/lib/CAPI/Support/CMakeLists.txt index f2d855e27..5baf0c241 100644 --- a/compiler/lib/CAPI/Support/CMakeLists.txt +++ b/compiler/lib/CAPI/Support/CMakeLists.txt @@ -1,5 +1,9 @@ set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp) +if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED) + add_compile_options(-DCONCRETELANG_PARALLEL_EXECUTION_ENABLED) +endif() + add_mlir_public_c_api_library(CONCRETELANGCAPISupport CompilerEngine.cpp diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 21d7d9072..bef38fa68 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -7,6 +7,7 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/ClientLib/KeySetCache.h" +#include "concretelang/Runtime/runtime_api.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Jit.h" #include "concretelang/Support/JitCompilerEngine.h" @@ -47,6 +48,24 @@ buildLambda(const char *module, const char *funcName, return std::move(*lambdaOrErr); } +void initParallelization() { +#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED + _dfr_pre_main(); +#else + throw std::runtime_error( + "This package was built without parallelization support"); +#endif +} + +void terminateParallelization() { +#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED + _dfr_post_main(); +#else + throw std::runtime_error( + "This package was built without parallelization support"); +#endif +} + lambdaArgument invokeLambda(lambda l, executionArguments args) { mlir::concretelang::JitCompilerEngine::Lambda *lambda_ptr = (mlir::concretelang::JitCompilerEngine::Lambda *)l.ptr;