mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: setup init/termination of parallel execution in python
This commit is contained in:
@@ -53,6 +53,12 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
|
||||
executionArguments args);
|
||||
|
||||
// Initialize and terminate parallelization. Init can be called only once (later
|
||||
// calls might be ignored by the runtime). You shouldn't reinit after
|
||||
// termination.
|
||||
MLIR_CAPI_EXPORTED void initParallelization();
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization();
|
||||
|
||||
// Create a lambdaArgument from a tensor of different data types
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions);
|
||||
|
||||
@@ -39,6 +39,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
return library(library_path, mlir_modules);
|
||||
});
|
||||
|
||||
m.def("init_parallelization", &initParallelization);
|
||||
m.def("terminate_parallelization", &terminateParallelization);
|
||||
|
||||
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def_static("build_lambda",
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
"""Compiler submodule"""
|
||||
from collections.abc import Iterable
|
||||
import os
|
||||
import atexit
|
||||
from typing import List, Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JitCompilerEngine as _JitCompilerEngine,
|
||||
init_parallelization as _init_parallelization,
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
@@ -20,6 +23,10 @@ ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
|
||||
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
|
||||
|
||||
|
||||
# Terminate parallelization in the compiler (if init) during cleanup
|
||||
atexit.register(_terminate_parallelization)
|
||||
|
||||
|
||||
def _lookup_runtime_lib() -> str:
|
||||
"""Try to find the absolute path to the runtime library.
|
||||
|
||||
@@ -177,6 +184,10 @@ class CompilerEngine:
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError("unsecure_key_set_cache_path must be a str")
|
||||
|
||||
if any([auto_parallelize, loop_parallelize, df_parallelize]):
|
||||
# Multiple calls should be guarded in the compiler and only result in a single init
|
||||
_init_parallelization()
|
||||
self._lambda = self._engine.build_lambda(
|
||||
mlir_str,
|
||||
func_name,
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
|
||||
|
||||
if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
|
||||
add_compile_options(-DCONCRETELANG_PARALLEL_EXECUTION_ENABLED)
|
||||
endif()
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPISupport
|
||||
|
||||
CompilerEngine.cpp
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
@@ -47,6 +48,24 @@ buildLambda(const char *module, const char *funcName,
|
||||
return std::move(*lambdaOrErr);
|
||||
}
|
||||
|
||||
void initParallelization() {
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
_dfr_pre_main();
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"This package was built without parallelization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
void terminateParallelization() {
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
_dfr_post_main();
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"This package was built without parallelization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
lambdaArgument invokeLambda(lambda l, executionArguments args) {
|
||||
mlir::concretelang::JitCompilerEngine::Lambda *lambda_ptr =
|
||||
(mlir::concretelang::JitCompilerEngine::Lambda *)l.ptr;
|
||||
|
||||
Reference in New Issue
Block a user