feat: setup init/termination of parallel execution in python

This commit is contained in:
youben11
2022-03-11 15:56:46 +01:00
committed by Ayoub Benaissa
parent 2f31edef7f
commit 44ebd426f9
5 changed files with 43 additions and 0 deletions

View File

@@ -53,6 +53,12 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
executionArguments args);
// Initialize and terminate parallelization. Init can be called only once (later
// calls might be ignored by the runtime). You shouldn't reinit after
// termination.
MLIR_CAPI_EXPORTED void initParallelization();
MLIR_CAPI_EXPORTED void terminateParallelization();
// Create a lambdaArgument from a tensor of different data types
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
std::vector<uint8_t> data, std::vector<int64_t> dimensions);

View File

@@ -39,6 +39,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return library(library_path, mlir_modules);
});
m.def("init_parallelization", &initParallelization);
m.def("terminate_parallelization", &terminateParallelization);
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
.def(pybind11::init())
.def_static("build_lambda",

View File

@@ -4,10 +4,13 @@
"""Compiler submodule"""
from collections.abc import Iterable
import os
import atexit
from typing import List, Union
from mlir._mlir_libs._concretelang._compiler import (
JitCompilerEngine as _JitCompilerEngine,
init_parallelization as _init_parallelization,
terminate_parallelization as _terminate_parallelization,
)
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
@@ -20,6 +23,10 @@ ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
# Terminate parallelization in the compiler (if init) during cleanup
atexit.register(_terminate_parallelization)
def _lookup_runtime_lib() -> str:
"""Try to find the absolute path to the runtime library.
@@ -177,6 +184,10 @@ class CompilerEngine:
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError("unsecure_key_set_cache_path must be a str")
if any([auto_parallelize, loop_parallelize, df_parallelize]):
# Multiple calls should be guarded in the compiler and only result in a single init
_init_parallelization()
self._lambda = self._engine.build_lambda(
mlir_str,
func_name,

View File

@@ -1,5 +1,9 @@
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
add_compile_options(-DCONCRETELANG_PARALLEL_EXECUTION_ENABLED)
endif()
add_mlir_public_c_api_library(CONCRETELANGCAPISupport
CompilerEngine.cpp

View File

@@ -7,6 +7,7 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
@@ -47,6 +48,24 @@ buildLambda(const char *module, const char *funcName,
return std::move(*lambdaOrErr);
}
void initParallelization() {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
_dfr_pre_main();
#else
throw std::runtime_error(
"This package was built without parallelization support");
#endif
}
void terminateParallelization() {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
_dfr_post_main();
#else
throw std::runtime_error(
"This package was built without parallelization support");
#endif
}
lambdaArgument invokeLambda(lambda l, executionArguments args) {
mlir::concretelang::JitCompilerEngine::Lambda *lambda_ptr =
(mlir::concretelang::JitCompilerEngine::Lambda *)l.ptr;