mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support parallelization in python
This commit is contained in:
@@ -33,6 +33,13 @@ else
|
||||
CXX_COMPILER_OPTION=
|
||||
endif
|
||||
|
||||
# don't run parallel python tests if compiler doesn't support it
|
||||
ifeq ($(PARALLEL_EXECUTION_ENABLED),ON)
|
||||
PYTHON_TESTS_MARKER=""
|
||||
else
|
||||
PYTHON_TESTS_MARKER="not parallel"
|
||||
endif
|
||||
|
||||
$(BUILD_DIR)/configured.stamp:
|
||||
cmake -B $(BUILD_DIR) -GNinja ../llvm-project/llvm/ \
|
||||
$(CMAKE_CCACHE_OPTIONS) \
|
||||
@@ -78,7 +85,7 @@ test-check: concretecompiler file-check not
|
||||
$(BUILD_DIR)/bin/llvm-lit -v tests/
|
||||
|
||||
test-python: python-bindings concretecompiler
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs -m $(PYTHON_TESTS_MARKER) tests/python
|
||||
|
||||
test: test-check test-end-to-end-jit test-python support-unit-test testlib-unit-test
|
||||
|
||||
|
||||
@@ -39,10 +39,11 @@ typedef struct executionArguments executionArguments;
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
|
||||
// null) as a shared library during compilation,
|
||||
// a path to activate the use a cache for encryption keys for test purpose
|
||||
// (unsecure).
|
||||
// (unsecure), and a set of flags for parallelization.
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath);
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
bool autoParallelize, bool loopParallelize, bool dfParallelize);
|
||||
|
||||
// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
@@ -43,11 +43,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def(pybind11::init())
|
||||
.def_static("build_lambda",
|
||||
[](std::string mlir_input, std::string func_name,
|
||||
std::string runtime_lib_path,
|
||||
std::string keysetcache_path) {
|
||||
std::string runtime_lib_path, std::string keysetcache_path,
|
||||
bool auto_parallelize, bool loop_parallelize,
|
||||
bool df_parallelize) {
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(),
|
||||
noEmptyStringPtr(runtime_lib_path),
|
||||
noEmptyStringPtr(keysetcache_path));
|
||||
noEmptyStringPtr(keysetcache_path),
|
||||
auto_parallelize, loop_parallelize,
|
||||
df_parallelize);
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
|
||||
@@ -6,7 +6,9 @@ from collections.abc import Iterable
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilerEngine as _JitCompilerEngine
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JitCompilerEngine as _JitCompilerEngine,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
from mlir._mlir_libs._concretelang._compiler import library as _library
|
||||
@@ -58,9 +60,11 @@ def round_trip(mlir_str: str) -> str:
|
||||
raise TypeError("input must be an `str`")
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str'
|
||||
|
||||
def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str:
|
||||
_MLIR_MODULES_TYPE = "mlir_modules must be an `iterable` of `str` or a `str"
|
||||
|
||||
|
||||
def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str:
|
||||
"""Compile the MLIR inputs to a library.
|
||||
|
||||
Args:
|
||||
@@ -74,7 +78,7 @@ def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str
|
||||
str: parsed MLIR input.
|
||||
"""
|
||||
if not isinstance(library_path, str):
|
||||
raise TypeError('library_path must be a `str`')
|
||||
raise TypeError("library_path must be a `str`")
|
||||
if isinstance(mlir_modules, str):
|
||||
mlir_modules = [mlir_modules]
|
||||
elif isinstance(mlir_modules, list):
|
||||
@@ -104,7 +108,9 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
_LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
)
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
@@ -128,8 +134,14 @@ class CompilerEngine:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(
|
||||
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None,
|
||||
self,
|
||||
mlir_str: str,
|
||||
func_name: str = "main",
|
||||
runtime_lib_path: str = None,
|
||||
unsecure_key_set_cache_path: str = None,
|
||||
auto_parallelize: bool = False,
|
||||
loop_parallelize: bool = False,
|
||||
df_parallelize: bool = False,
|
||||
):
|
||||
"""Compile the MLIR input.
|
||||
|
||||
@@ -138,6 +150,9 @@ class CompilerEngine:
|
||||
func_name (str): name of the function to set as entrypoint (default: main).
|
||||
runtime_lib_path (str): path to the runtime lib (default: None).
|
||||
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
|
||||
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
|
||||
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
|
||||
df_parallelize (bool): whether to activate dataflow-parallelization or not (default: False),
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
@@ -152,14 +167,25 @@ class CompilerEngine:
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
if not all(
|
||||
isinstance(flag, bool)
|
||||
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
|
||||
):
|
||||
raise TypeError(
|
||||
"parallelization flags (auto_parallelize, loop_parallelize, df_parallelize), should be booleans"
|
||||
)
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError(
|
||||
"unsecure_key_set_cache_path must be a str"
|
||||
)
|
||||
raise TypeError("unsecure_key_set_cache_path must be a str")
|
||||
self._lambda = self._engine.build_lambda(
|
||||
mlir_str, func_name, runtime_lib_path,
|
||||
unsecure_key_set_cache_path)
|
||||
mlir_str,
|
||||
func_name,
|
||||
runtime_lib_path,
|
||||
unsecure_key_set_cache_path,
|
||||
auto_parallelize,
|
||||
loop_parallelize,
|
||||
df_parallelize,
|
||||
)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
|
||||
@@ -15,13 +15,19 @@ using mlir::concretelang::JitCompilerEngine;
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath) {
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
bool autoParallelize, bool loopParallelize, bool dfParallelize) {
|
||||
// Set the runtime library path if not nullptr
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
|
||||
if (runtimeLibPath != nullptr)
|
||||
runtimeLibPathOptional = runtimeLibPath;
|
||||
mlir::concretelang::JitCompilerEngine engine;
|
||||
|
||||
// Set parallelization flags
|
||||
engine.setAutoParallelize(autoParallelize);
|
||||
engine.setLoopParallelize(loopParallelize);
|
||||
engine.setDataflowParallelize(dfParallelize);
|
||||
|
||||
using KeySetCache = mlir::concretelang::KeySetCache;
|
||||
using optKeySetCache = llvm::Optional<mlir::concretelang::KeySetCache>;
|
||||
auto cacheOpt = optKeySetCache();
|
||||
|
||||
54
compiler/tests/python/test_compiler_engine_parallel.py
Normal file
54
compiler/tests/python/test_compiler_engine_parallel.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="add_eint_int",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint8",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_parallel(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input,
|
||||
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
|
||||
auto_parallelize=True
|
||||
)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
# numpy array
|
||||
assert np.all(engine.run(*args) == expected_result)
|
||||
Reference in New Issue
Block a user