feat: support parallelization in python

This commit is contained in:
youben11
2022-03-09 14:35:05 +01:00
committed by Ayoub Benaissa
parent 14faa4c7df
commit 5b37ec640c
6 changed files with 115 additions and 18 deletions

View File

@@ -33,6 +33,13 @@ else
CXX_COMPILER_OPTION=
endif
# don't run parallel python tests if compiler doesn't support it
ifeq ($(PARALLEL_EXECUTION_ENABLED),ON)
PYTHON_TESTS_MARKER=""
else
PYTHON_TESTS_MARKER="not parallel"
endif
$(BUILD_DIR)/configured.stamp:
cmake -B $(BUILD_DIR) -GNinja ../llvm-project/llvm/ \
$(CMAKE_CCACHE_OPTIONS) \
@@ -78,7 +85,7 @@ test-check: concretecompiler file-check not
$(BUILD_DIR)/bin/llvm-lit -v tests/
test-python: python-bindings concretecompiler
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs -m $(PYTHON_TESTS_MARKER) tests/python
test: test-check test-end-to-end-jit test-python support-unit-test testlib-unit-test

View File

@@ -39,10 +39,11 @@ typedef struct executionArguments executionArguments;
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
// null) as a shared library during compilation,
// a path to activate the use a cache for encryption keys for test purpose
// (unsecure).
// (unsecure), and a set of flags for parallelization.
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath);
const char *runtimeLibPath, const char *keySetCachePath,
bool autoParallelize, bool loopParallelize, bool dfParallelize);
// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);

View File

@@ -43,11 +43,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(pybind11::init())
.def_static("build_lambda",
[](std::string mlir_input, std::string func_name,
std::string runtime_lib_path,
std::string keysetcache_path) {
std::string runtime_lib_path, std::string keysetcache_path,
bool auto_parallelize, bool loop_parallelize,
bool df_parallelize) {
return buildLambda(mlir_input.c_str(), func_name.c_str(),
noEmptyStringPtr(runtime_lib_path),
noEmptyStringPtr(keysetcache_path));
noEmptyStringPtr(keysetcache_path),
auto_parallelize, loop_parallelize,
df_parallelize);
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")

View File

@@ -6,7 +6,9 @@ from collections.abc import Iterable
import os
from typing import List, Union
from mlir._mlir_libs._concretelang._compiler import JitCompilerEngine as _JitCompilerEngine
from mlir._mlir_libs._concretelang._compiler import (
JitCompilerEngine as _JitCompilerEngine,
)
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from mlir._mlir_libs._concretelang._compiler import library as _library
@@ -58,9 +60,11 @@ def round_trip(mlir_str: str) -> str:
raise TypeError("input must be an `str`")
return _round_trip(mlir_str)
_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str'
def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str:
_MLIR_MODULES_TYPE = "mlir_modules must be an `iterable` of `str` or a `str"
def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str:
"""Compile the MLIR inputs to a library.
Args:
@@ -74,7 +78,7 @@ def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str
str: parsed MLIR input.
"""
if not isinstance(library_path, str):
raise TypeError('library_path must be a `str`')
raise TypeError("library_path must be a `str`")
if isinstance(mlir_modules, str):
mlir_modules = [mlir_modules]
elif isinstance(mlir_modules, list):
@@ -104,7 +108,9 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
_LambdaArgument: lambda argument holding the appropriate value
"""
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
raise TypeError(
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
)
if isinstance(value, ACCEPTED_INTS):
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
raise TypeError(
@@ -128,8 +134,14 @@ class CompilerEngine:
self.compile_fhe(mlir_str)
def compile_fhe(
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None,
self,
mlir_str: str,
func_name: str = "main",
runtime_lib_path: str = None,
unsecure_key_set_cache_path: str = None,
auto_parallelize: bool = False,
loop_parallelize: bool = False,
df_parallelize: bool = False,
):
"""Compile the MLIR input.
@@ -138,6 +150,9 @@ class CompilerEngine:
func_name (str): name of the function to set as entrypoint (default: main).
runtime_lib_path (str): path to the runtime lib (default: None).
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
df_parallelize (bool): whether to activate dataflow-parallelization or not (default: False),
Raises:
TypeError: if the argument is not an str.
@@ -152,14 +167,25 @@ class CompilerEngine:
raise TypeError(
"runtime_lib_path must be an str representing the path to the runtime lib"
)
if not all(
isinstance(flag, bool)
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
):
raise TypeError(
"parallelization flags (auto_parallelize, loop_parallelize, df_parallelize), should be booleans"
)
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError(
"unsecure_key_set_cache_path must be a str"
)
raise TypeError("unsecure_key_set_cache_path must be a str")
self._lambda = self._engine.build_lambda(
mlir_str, func_name, runtime_lib_path,
unsecure_key_set_cache_path)
mlir_str,
func_name,
runtime_lib_path,
unsecure_key_set_cache_path,
auto_parallelize,
loop_parallelize,
df_parallelize,
)
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
"""Run the compiled code.

View File

@@ -15,13 +15,19 @@ using mlir::concretelang::JitCompilerEngine;
mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath) {
const char *runtimeLibPath, const char *keySetCachePath,
bool autoParallelize, bool loopParallelize, bool dfParallelize) {
// Set the runtime library path if not nullptr
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
if (runtimeLibPath != nullptr)
runtimeLibPathOptional = runtimeLibPath;
mlir::concretelang::JitCompilerEngine engine;
// Set parallelization flags
engine.setAutoParallelize(autoParallelize);
engine.setLoopParallelize(loopParallelize);
engine.setDataflowParallelize(dfParallelize);
using KeySetCache = mlir::concretelang::KeySetCache;
using optKeySetCache = llvm::Optional<mlir::concretelang::KeySetCache>;
auto cacheOpt = optKeySetCache();

View File

@@ -0,0 +1,54 @@
import os
import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
@pytest.mark.parallel
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
[
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(5, 7),
12,
id="add_eint_int",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
return %ret : !FHE.eint<7>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
id="dot_eint_int_uint8",
),
],
)
def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input,
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
auto_parallelize=True
)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
# numpy array
assert np.all(engine.run(*args) == expected_result)