diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 28012b345..53751d8a0 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -84,6 +84,12 @@ library_server_call(LibrarySupport_C support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args); +MLIR_CAPI_EXPORTED std::string +library_get_shared_lib_path(LibrarySupport_C support); + +MLIR_CAPI_EXPORTED std::string +library_get_client_parameters_path(LibrarySupport_C support); + // Client Support bindings /////////////////////////////////////////////////// MLIR_CAPI_EXPORTED std::unique_ptr diff --git a/compiler/include/concretelang/Support/LibrarySupport.h b/compiler/include/concretelang/Support/LibrarySupport.h index cc3b14c05..a6d92a9aa 100644 --- a/compiler/include/concretelang/Support/LibrarySupport.h +++ b/compiler/include/concretelang/Support/LibrarySupport.h @@ -108,6 +108,16 @@ public: return lambda.call(args); } + /// Get path to shared library + std::string getSharedLibPath() { + return CompilerEngine::Library::getSharedLibraryPath(outputPath); + } + + /// Get path to client parameters file + std::string getClientParametersPath() { + return CompilerEngine::Library::getClientParametersPath(outputPath); + } + private: std::string outputPath; std::string runtimeLibraryPath; diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 9344bfc5e..461422ead 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -134,7 +134,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](LibrarySupport_C &support, serverlib::ServerLambda lambda, clientlib::PublicArguments &publicArguments) { return library_server_call(support, lambda, publicArguments); - }); + }) + .def("get_shared_lib_path", + [](LibrarySupport_C &support) { + return library_get_shared_lib_path(support); + }) + .def("get_client_parameters_path", [](LibrarySupport_C &support) { + return library_get_client_parameters_path(support); + }); class ClientSupport {}; pybind11::class_(m, "ClientSupport") diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py index a08d84603..678981acc 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py @@ -237,3 +237,19 @@ class LibrarySupport(WrapperCpp): return PublicResult.wrap( self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp()) ) + + def get_shared_lib_path(self) -> str: + """Get the path where the shared library is expected to be. + + Returns: + str: path to the shared library + """ + return self.cpp().get_shared_lib_path() + + def get_client_parameters_path(self) -> str: + """Get the path where the client parameters file is expected to be. + + Returns: + str: path to the client parameters file + """ + return self.cpp().get_client_parameters_path() diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index e45f7132b..1d3769bbb 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -103,6 +103,16 @@ library_server_call(LibrarySupport_C support, return std::move(*publicResult); } +MLIR_CAPI_EXPORTED std::string +library_get_shared_lib_path(LibrarySupport_C support) { + return support.support.getSharedLibPath(); +} + +MLIR_CAPI_EXPORTED std::string +library_get_client_parameters_path(LibrarySupport_C support) { + return support.support.getClientParametersPath(); +} + // Client Support bindings /////////////////////////////////////////////////// MLIR_CAPI_EXPORTED std::unique_ptr diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 50bd7aa16..cb2bec580 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -1,4 +1,6 @@ import pytest +import os.path +import shutil import numpy as np from concrete.compiler import ( JITSupport, @@ -167,19 +169,40 @@ def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache): @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache): - engine = LibrarySupport.new("./py_test_lib_compile_and_run") + artifact_dir = "./py_test_lib_compile_and_run" + engine = LibrarySupport.new(artifact_dir) compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache) + shutil.rmtree(artifact_dir) @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache): - engine = LibrarySupport.new("./test_lib_compile_reload_and_run") + artifact_dir = "./test_lib_compile_reload_and_run" + engine = LibrarySupport.new(artifact_dir) # Here don't save compilation result, reload engine.compile(mlir_input) compilation_result = engine.reload() result = run(engine, args, compilation_result, keyset_cache) # Check result assert_result(result, expected_result) + shutil.rmtree(artifact_dir) + + +def test_lib_compilation_artifacts(): + mlir_str = """ + func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } + """ + artifact_dir = "./test_artifacts" + engine = LibrarySupport.new(artifact_dir) + engine.compile(mlir_str) + assert os.path.exists(engine.get_client_parameters_path()) + assert os.path.exists(engine.get_shared_lib_path()) + shutil.rmtree(artifact_dir) + assert not os.path.exists(engine.get_client_parameters_path()) + assert not os.path.exists(engine.get_shared_lib_path()) def test_lib_compile_and_run_p_error(keyset_cache): diff --git a/compiler/tests/python/test_serialization.py b/compiler/tests/python/test_serialization.py index cc543e149..ba634e32e 100644 --- a/compiler/tests/python/test_serialization.py +++ b/compiler/tests/python/test_serialization.py @@ -1,4 +1,5 @@ import pytest +import shutil import numpy as np from concrete.compiler import ( JITSupport, @@ -143,7 +144,9 @@ def test_jit_compile_and_run_with_serialization( def test_lib_compile_and_run_with_serialization( mlir_input, args, expected_result, keyset_cache ): - engine = LibrarySupport.new("./py_test_lib_compile_and_run") + artifact_dir = "./py_test_lib_compile_and_run" + engine = LibrarySupport.new(artifact_dir) compile_run_assert_with_serialization( engine, mlir_input, args, expected_result, keyset_cache ) + shutil.rmtree(artifact_dir)