mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(python): get path to diff artifacts
- path to client parameters file - path to shared library
This commit is contained in:
@@ -84,6 +84,12 @@ library_server_call(LibrarySupport_C support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_C support);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_C support);
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
|
||||
@@ -108,6 +108,16 @@ public:
|
||||
return lambda.call(args);
|
||||
}
|
||||
|
||||
/// Get path to shared library
|
||||
std::string getSharedLibPath() {
|
||||
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
|
||||
}
|
||||
|
||||
/// Get path to client parameters file
|
||||
std::string getClientParametersPath() {
|
||||
return CompilerEngine::Library::getClientParametersPath(outputPath);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string outputPath;
|
||||
std::string runtimeLibraryPath;
|
||||
|
||||
@@ -134,7 +134,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](LibrarySupport_C &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return library_server_call(support, lambda, publicArguments);
|
||||
});
|
||||
})
|
||||
.def("get_shared_lib_path",
|
||||
[](LibrarySupport_C &support) {
|
||||
return library_get_shared_lib_path(support);
|
||||
})
|
||||
.def("get_client_parameters_path", [](LibrarySupport_C &support) {
|
||||
return library_get_client_parameters_path(support);
|
||||
});
|
||||
|
||||
class ClientSupport {};
|
||||
pybind11::class_<ClientSupport>(m, "ClientSupport")
|
||||
|
||||
@@ -237,3 +237,19 @@ class LibrarySupport(WrapperCpp):
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
|
||||
)
|
||||
|
||||
def get_shared_lib_path(self) -> str:
|
||||
"""Get the path where the shared library is expected to be.
|
||||
|
||||
Returns:
|
||||
str: path to the shared library
|
||||
"""
|
||||
return self.cpp().get_shared_lib_path()
|
||||
|
||||
def get_client_parameters_path(self) -> str:
|
||||
"""Get the path where the client parameters file is expected to be.
|
||||
|
||||
Returns:
|
||||
str: path to the client parameters file
|
||||
"""
|
||||
return self.cpp().get_client_parameters_path()
|
||||
|
||||
@@ -103,6 +103,16 @@ library_server_call(LibrarySupport_C support,
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_C support) {
|
||||
return support.support.getSharedLibPath();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_C support) {
|
||||
return support.support.getClientParametersPath();
|
||||
}
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import pytest
|
||||
import os.path
|
||||
import shutil
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
JITSupport,
|
||||
@@ -167,19 +169,40 @@ def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache):
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache):
|
||||
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
|
||||
artifact_dir = "./py_test_lib_compile_and_run"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache)
|
||||
shutil.rmtree(artifact_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache):
|
||||
engine = LibrarySupport.new("./test_lib_compile_reload_and_run")
|
||||
artifact_dir = "./test_lib_compile_reload_and_run"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
# Here don't save compilation result, reload
|
||||
engine.compile(mlir_input)
|
||||
compilation_result = engine.reload()
|
||||
result = run(engine, args, compilation_result, keyset_cache)
|
||||
# Check result
|
||||
assert_result(result, expected_result)
|
||||
shutil.rmtree(artifact_dir)
|
||||
|
||||
|
||||
def test_lib_compilation_artifacts():
|
||||
mlir_str = """
|
||||
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
"""
|
||||
artifact_dir = "./test_artifacts"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
engine.compile(mlir_str)
|
||||
assert os.path.exists(engine.get_client_parameters_path())
|
||||
assert os.path.exists(engine.get_shared_lib_path())
|
||||
shutil.rmtree(artifact_dir)
|
||||
assert not os.path.exists(engine.get_client_parameters_path())
|
||||
assert not os.path.exists(engine.get_shared_lib_path())
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import shutil
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
JITSupport,
|
||||
@@ -143,7 +144,9 @@ def test_jit_compile_and_run_with_serialization(
|
||||
def test_lib_compile_and_run_with_serialization(
|
||||
mlir_input, args, expected_result, keyset_cache
|
||||
):
|
||||
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
|
||||
artifact_dir = "./py_test_lib_compile_and_run"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compile_run_assert_with_serialization(
|
||||
engine, mlir_input, args, expected_result, keyset_cache
|
||||
)
|
||||
shutil.rmtree(artifact_dir)
|
||||
|
||||
Reference in New Issue
Block a user