mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
cleanup(python): Cleanup python bindings after refactoring of CompilerEngine
This commit is contained in:
@@ -7,15 +7,10 @@ import os
|
||||
import atexit
|
||||
from typing import List, Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JitCompilerEngine as _JitCompilerEngine,
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
from mlir._mlir_libs._concretelang._compiler import terminate_parallelization as _terminate_parallelization
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
from mlir._mlir_libs._concretelang._compiler import library as _library
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientParameters
|
||||
@@ -25,10 +20,13 @@ from mlir._mlir_libs._concretelang._compiler import KeySetCache
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicResult
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicArguments
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambda
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport as _LibraryLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
|
||||
import numpy as np
|
||||
@@ -85,70 +83,6 @@ def round_trip(mlir_str: str) -> str:
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
|
||||
_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str'
|
||||
|
||||
|
||||
def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str:
|
||||
"""Compile the MLIR inputs to a library.
|
||||
|
||||
Args:
|
||||
library_path (str): destination path of the library
|
||||
mlir_modules (list[str]|str): code of MLIR modules
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments have incorrect types.
|
||||
|
||||
Returns:
|
||||
str: parsed MLIR input.
|
||||
"""
|
||||
if not isinstance(library_path, str):
|
||||
raise TypeError("library_path must be a `str`")
|
||||
if isinstance(mlir_modules, str):
|
||||
mlir_modules = [mlir_modules]
|
||||
elif isinstance(mlir_modules, list):
|
||||
pass
|
||||
elif isinstance(mlir_modules, Iterable):
|
||||
mlir_modules = list(mlir_modules)
|
||||
else:
|
||||
mlir_modules = [None]
|
||||
raise TypeError(_MLIR_MODULES_TYPE)
|
||||
|
||||
if not all(isinstance(m, str) for m in mlir_modules):
|
||||
raise TypeError(_MLIR_MODULES_TYPE)
|
||||
|
||||
return _library(library_path, mlir_modules)
|
||||
|
||||
|
||||
def create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
|
||||
"""Create an execution argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
|
||||
|
||||
Raises:
|
||||
TypeError: if the values aren't in the expected range, or using a wrong type
|
||||
|
||||
Returns:
|
||||
_LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.shape == ():
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = JITCompilerSupport()
|
||||
@@ -264,7 +198,8 @@ class ClientSupport:
|
||||
Returns:
|
||||
PublicArguments: the public arguments
|
||||
"""
|
||||
execution_arguments = [create_execution_argument(arg) for arg in args]
|
||||
execution_arguments = [
|
||||
ClientSupport._create_execution_argument(arg) for arg in args]
|
||||
return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments)
|
||||
|
||||
def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]:
|
||||
@@ -287,12 +222,42 @@ class ClientSupport:
|
||||
else:
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
def _create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
|
||||
"""Create an execution argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
|
||||
|
||||
Raises:
|
||||
TypeError: if the values aren't in the expected range, or using a wrong type
|
||||
|
||||
Returns:
|
||||
_LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.shape == ():
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError(
|
||||
"numpy.array must be of dtype uint{8,16,32,64}")
|
||||
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
|
||||
|
||||
class JITCompilerSupport:
|
||||
def __init__(self, runtime_lib_path=None):
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
self._support = JITLambdaSupport(runtime_lib_path)
|
||||
self._support = _JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult:
|
||||
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
|
||||
@@ -332,7 +297,7 @@ class JITCompilerSupport:
|
||||
class LibraryCompilerSupport:
|
||||
def __init__(self, outputPath="./out"):
|
||||
self._library_path = outputPath
|
||||
self._support = LibraryLambdaSupport(outputPath)
|
||||
self._support = _LibraryLambdaSupport(outputPath)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
|
||||
|
||||
@@ -3,7 +3,7 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine, library
|
||||
from concrete.compiler import CompilerEngine
|
||||
from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from lib.Bindings.Python.concrete.compiler import ClientSupport
|
||||
from lib.Bindings.Python.concrete.compiler import KeySetCache
|
||||
@@ -313,7 +313,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result, tab_size",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -338,7 +338,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -359,55 +359,3 @@ def test_compile_invalid(mlir_input):
|
||||
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
MODULE_1 = """
|
||||
func @test1()
|
||||
{
|
||||
return
|
||||
}
|
||||
"""
|
||||
MODULE_2 = """
|
||||
func @test2()
|
||||
{
|
||||
return
|
||||
}
|
||||
"""
|
||||
LIB_PATH = './test_library_generation.so'
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'mlir_modules',
|
||||
[
|
||||
pytest.param(MODULE_1, id='1 module'),
|
||||
pytest.param([MODULE_1, MODULE_2], id='2 modules'),
|
||||
pytest.param(iter([MODULE_1, MODULE_2]), id='iterable'),
|
||||
],
|
||||
)
|
||||
def test_library_generation(mlir_modules):
|
||||
library_path = library(LIB_PATH, mlir_modules)
|
||||
assert os.path.exists(library_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'mlir_modules',
|
||||
[
|
||||
pytest.param(bytes(MODULE_1, encoding='utf-8'), id='bytes vs str'),
|
||||
pytest.param(None, id='not iterable'),
|
||||
pytest.param([None], id='not str'),
|
||||
],
|
||||
)
|
||||
def test_library_generation_type_error(mlir_modules):
|
||||
with pytest.raises(TypeError):
|
||||
library(LIB_PATH, mlir_modules)
|
||||
|
||||
|
||||
def test_library_call():
|
||||
module = """
|
||||
func @test(%a: i8) -> i8
|
||||
{
|
||||
return %a : i8
|
||||
}
|
||||
"""
|
||||
from ctypes import CDLL
|
||||
lib = CDLL(library(LIB_PATH, module))
|
||||
assert lib.test(13) == 13
|
||||
|
||||
Reference in New Issue
Block a user