feat(python): get path to diff artifacts

- path to client parameters file
- path to shared library
This commit is contained in:
youben11
2022-04-29 15:09:33 +01:00
committed by Ayoub Benaissa
parent 211241fcb2
commit 5f1a539505
7 changed files with 79 additions and 4 deletions

View File

@@ -84,6 +84,12 @@ library_server_call(LibrarySupport_C support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args);
MLIR_CAPI_EXPORTED std::string
library_get_shared_lib_path(LibrarySupport_C support);
MLIR_CAPI_EXPORTED std::string
library_get_client_parameters_path(LibrarySupport_C support);
// Client Support bindings ///////////////////////////////////////////////////
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>

View File

@@ -108,6 +108,16 @@ public:
return lambda.call(args);
}
/// Get path to shared library
std::string getSharedLibPath() {
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
}
/// Get path to client parameters file
std::string getClientParametersPath() {
return CompilerEngine::Library::getClientParametersPath(outputPath);
}
private:
std::string outputPath;
std::string runtimeLibraryPath;

View File

@@ -134,7 +134,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](LibrarySupport_C &support, serverlib::ServerLambda lambda,
clientlib::PublicArguments &publicArguments) {
return library_server_call(support, lambda, publicArguments);
});
})
.def("get_shared_lib_path",
[](LibrarySupport_C &support) {
return library_get_shared_lib_path(support);
})
.def("get_client_parameters_path", [](LibrarySupport_C &support) {
return library_get_client_parameters_path(support);
});
class ClientSupport {};
pybind11::class_<ClientSupport>(m, "ClientSupport")

View File

@@ -237,3 +237,19 @@ class LibrarySupport(WrapperCpp):
return PublicResult.wrap(
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
)
def get_shared_lib_path(self) -> str:
"""Get the path where the shared library is expected to be.
Returns:
str: path to the shared library
"""
return self.cpp().get_shared_lib_path()
def get_client_parameters_path(self) -> str:
"""Get the path where the client parameters file is expected to be.
Returns:
str: path to the client parameters file
"""
return self.cpp().get_client_parameters_path()

View File

@@ -103,6 +103,16 @@ library_server_call(LibrarySupport_C support,
return std::move(*publicResult);
}
MLIR_CAPI_EXPORTED std::string
library_get_shared_lib_path(LibrarySupport_C support) {
return support.support.getSharedLibPath();
}
MLIR_CAPI_EXPORTED std::string
library_get_client_parameters_path(LibrarySupport_C support) {
return support.support.getClientParametersPath();
}
// Client Support bindings ///////////////////////////////////////////////////
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>

View File

@@ -1,4 +1,6 @@
import pytest
import os.path
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
@@ -167,19 +169,40 @@ def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache):
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache):
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
artifact_dir = "./py_test_lib_compile_and_run"
engine = LibrarySupport.new(artifact_dir)
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache)
shutil.rmtree(artifact_dir)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache):
engine = LibrarySupport.new("./test_lib_compile_reload_and_run")
artifact_dir = "./test_lib_compile_reload_and_run"
engine = LibrarySupport.new(artifact_dir)
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
result = run(engine, args, compilation_result, keyset_cache)
# Check result
assert_result(result, expected_result)
shutil.rmtree(artifact_dir)
def test_lib_compilation_artifacts():
mlir_str = """
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
"""
artifact_dir = "./test_artifacts"
engine = LibrarySupport.new(artifact_dir)
engine.compile(mlir_str)
assert os.path.exists(engine.get_client_parameters_path())
assert os.path.exists(engine.get_shared_lib_path())
shutil.rmtree(artifact_dir)
assert not os.path.exists(engine.get_client_parameters_path())
assert not os.path.exists(engine.get_shared_lib_path())
def test_lib_compile_and_run_p_error(keyset_cache):

View File

@@ -1,4 +1,5 @@
import pytest
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
@@ -143,7 +144,9 @@ def test_jit_compile_and_run_with_serialization(
def test_lib_compile_and_run_with_serialization(
mlir_input, args, expected_result, keyset_cache
):
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
artifact_dir = "./py_test_lib_compile_and_run"
engine = LibrarySupport.new(artifact_dir)
compile_run_assert_with_serialization(
engine, mlir_input, args, expected_result, keyset_cache
)
shutil.rmtree(artifact_dir)