fix: add a wrapper to compiled circuits to unify invocation

this was already implemented for JIT using mlir::ExecutionEngine, but
library compilation and execution used a different, more complex
mechanism, which caused a bad calling convention at the assembly level
on macOS M1 machines. This commit unifies the invocation of JIT- and
library-compiled circuits, solving the previously mentioned issue, and
also makes it possible to extend compiled libraries to support more
than one returned value.
This commit is contained in:
youben11
2023-03-03 16:34:10 +01:00
committed by Ayoub Benaissa
parent b1a94ac245
commit dc8b762708
16 changed files with 275 additions and 1994 deletions

View File

@@ -1,27 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
#define CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
#include <vector>
#include "concretelang/ClientLib/Types.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::ScalarOrTensorData;
/// Calls the type-erased compiled-circuit entry point `func` with `args` and
/// converts its return value — a scalar for `rank == 0`, otherwise a
/// rank-`rank` memref descriptor — into a ScalarOrTensorData with the given
/// element width and signedness.
ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
                                                 std::vector<void *> args,
                                                 size_t rank,
                                                 size_t element_width,
                                                 bool is_signed);
} // namespace serverlib
} // namespace concretelang
#endif

View File

@@ -15,6 +15,7 @@
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/Support/Error.h"
namespace concretelang {
namespace serverlib {
@@ -36,13 +37,21 @@ public:
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
/// Call the ServerLambda with public arguments.
std::unique_ptr<clientlib::PublicResult>
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys);
/// \brief Call the loaded function using opaque pointers to both inputs and
/// outputs.
/// \param args Array containing pointers to inputs first, followed by
/// pointers to outputs.
/// \return Error if failed, success otherwise.
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
protected:
ClientParameters clientParameters;
void *(*func)(void *...);
/// holds a pointer to the entrypoint of the shared lib which
void (*func)(void *...);
/// Retain module and open shared lib alive
std::shared_ptr<DynamicModule> module;
};

View File

@@ -1,45 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
# Generates DynamicArityCall.h: a `multi_arity_call` dispatcher that invokes a
# type-erased `Res (*)(void *...)` entry point with an argument count that is
# only known at runtime, by switching over every supported arity (1..127).
#
# Fixes over the previous version: the generated code now lives in
# `concretelang::serverlib` — matching DynamicRankCall.h/.cpp, which call
# `multi_arity_call` unqualified from that namespace (it previously opened
# `mlir::serverlib` while the closing comments claimed `concretelang`) —
# points at the canonical license URL, and aborts in the unsupported-arity
# branch so the value-returning function cannot fall off its end when asserts
# are compiled out.
print(
    """// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// generated: see genDynamicArityCall.py

#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H

#include <cassert>
#include <cstdlib>
#include <vector>

#include "concretelang/ClientLib/Types.h"

namespace concretelang {
namespace serverlib {

template <typename Res>
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
  switch (args.size()) {
  // TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
""")

# One `case` per supported arity, forwarding the pointers positionally.
for arity in range(1, 128):
    forwarded = ','.join(f'args[{j}]' for j in range(arity))
    print(f'  case {arity}: return func({forwarded});')

print("""
  default:
    // Arity 0 or >= 128 is not supported by this generated dispatcher.
    assert(false);
    abort(); // keep behavior defined when asserts are compiled out
  }
}

} // namespace serverlib
} // namespace concretelang

#endif
""")

View File

@@ -41,7 +41,6 @@ public:
void setUseDataflow(bool option) { this->useDataflow = option; }
private:
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
/// used to store the result of the computation.
/// Example:

View File

@@ -6,6 +6,8 @@
#ifndef CONCRETELANG_SUPPORT_LLVMEMITFILE
#define CONCRETELANG_SUPPORT_LLVMEMITFILE
#include <llvm/ADT/StringRef.h>
namespace mlir {
namespace concretelang {

View File

@@ -0,0 +1,114 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_UTILS_H_
#define CONCRETELANG_SUPPORT_UTILS_H_
#include <concretelang/ClientLib/ClientLambda.h>
#include <concretelang/ClientLib/KeySet.h>
#include <concretelang/ClientLib/PublicArguments.h>
#include <concretelang/ClientLib/Serializers.h>
#include <concretelang/Runtime/context.h>
#include <concretelang/ServerLib/ServerLambda.h>
#include <concretelang/Support/Error.h>
namespace concretelang {
// construct the function name of the wrapper function that unify function calls
// of compiled circuit
std::string makePackedFunctionName(llvm::StringRef name);
// memref is a struct which is flattened aligned, allocated pointers, offset,
// and two array of rank size for sizes and strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank);
/// Invokes a compiled circuit through the unified "packed" interface shared by
/// the JIT and library execution paths.
///
/// \tparam Lambda Any type exposing
///         `llvm::Error invokeRaw(llvm::MutableArrayRef<void *>)`
///         (e.g. JITLambda or serverlib::ServerLambda).
/// \param lambda The lambda to invoke.
/// \param clientParameters Circuit signature, used to size and then decode the
///        raw output buffer.
/// \param preparedInputArgs Type-erased input arguments, already flattened
///        following the calling convention.
/// \param evaluationKeys Evaluation keys, wrapped into the runtime context
///        passed as the last input argument.
/// \return The circuit outputs as a PublicResult, or the invocation error.
template <typename Lambda>
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
                  std::vector<void *> preparedInputArgs,
                  clientlib::EvaluationKeys &evaluationKeys) {
  // invokeRaw needs to have pointers on arguments and a pointer on the result
  // buffer as last argument.

  // Prepare the outputs vector to store the output values of the lambda.
  auto numOutputs = 0;
  for (auto &output : clientParameters.outputs) {
    auto shape = clientParameters.bufferShape(output);
    if (shape.size() == 0) {
      // scalar gate: one uint64_t slot
      numOutputs += 1;
    } else {
      // buffer gate: allocated + aligned pointers, offset, sizes and strides
      numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
    }
  }
  std::vector<uint64_t> outputs(numOutputs);

  // Prepare the raw arguments of invokeRaw, i.e. a vector with pointers on
  // inputs and outputs.
  std::vector<void *> rawArgs(
      preparedInputArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
  );
  size_t i = 0;
  // Pointers on inputs
  for (auto &arg : preparedInputArgs) {
    rawArgs[i++] = &arg;
  }

  mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
  // Pointer on runtime context; rawArgs takes a pointer on the actual value
  // that is passed to the compiled function.
  auto rtCtxPtr = &runtimeContext;
  rawArgs[i++] = &rtCtxPtr;

  // Outputs: a single pointer to the flat uint64_t output buffer.
  rawArgs[i++] = reinterpret_cast<void *>(outputs.data());

  // Invoke
  if (auto err = lambda->invokeRaw(rawArgs)) {
    return std::move(err);
  }

  // Decode the flat output buffer into the PublicResult, walking it with the
  // same layout used when sizing it above.
  std::vector<clientlib::ScalarOrTensorData> buffers;
  {
    size_t outputOffset = 0;
    for (auto &output : clientParameters.outputs) {
      auto shape = clientParameters.bufferShape(output);
      if (shape.size() == 0) {
        // scalar gate: the value sits directly in one slot
        buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
            concretelang::clientlib::ScalarData(outputs[outputOffset++],
                                                output.shape.sign,
                                                output.shape.width)));
      } else {
        // buffer gate: decode the flattened memref descriptor
        auto rank = shape.size();
        auto allocated = (uint64_t *)outputs[outputOffset++];
        auto aligned = (uint64_t *)outputs[outputOffset++];
        auto offset = (size_t)outputs[outputOffset++];
        // sizes/strides are read in place from the output buffer
        size_t *sizes = (size_t *)&outputs[outputOffset];
        outputOffset += rank;
        size_t *strides = (size_t *)&outputs[outputOffset];
        outputOffset += rank;
        // encrypted outputs are stored as 64-bit ciphertext words
        size_t elementWidth = (output.isEncrypted())
                                  ? clientlib::EncryptedScalarElementWidth
                                  : output.shape.width;
        bool sign = (output.isEncrypted()) ? false : output.shape.sign;
        concretelang::clientlib::TensorData td =
            clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
                                            aligned, offset, sizes, strides);
        buffers.push_back(
            concretelang::clientlib::ScalarOrTensorData(std::move(td)));
      }
    }
  }
  return clientlib::PublicResult::fromBuffers(clientParameters,
                                              std::move(buffers));
}
} // namespace concretelang
#endif

View File

@@ -90,9 +90,12 @@ public:
// server function call
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
if (!publicResult) {
return StringError("failed calling function");
}
// client result decryption
return this->decryptResult(*keySet, *publicResult);
return this->decryptResult(*keySet, *(publicResult.get()));
}
private:

View File

@@ -7,7 +7,6 @@ endif()
add_mlir_library(
ConcretelangServerLib
DynamicRankCall.cpp
ServerLambda.cpp
DynamicModule.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -1,247 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
// generated: see genDynamicRandCall.py
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

#include "concretelang/ClientLib/Types.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace concretelang {
namespace serverlib {
/// Maps a byte size to the unsigned integer type of exactly that width.
/// Only the sizes needed for function-pointer round-trips are specialized.
template <std::size_t size> struct int_type_of_size {};
template <> struct int_type_of_size<4> { using type = uint32_t; };
template <> struct int_type_of_size<8> { using type = uint64_t; };
/// Converts one function pointer type into another.
///
/// A direct reinterpret_cast between function pointer types is explicitly
/// permitted by the standard ([expr.reinterpret.cast]); calling through the
/// converted pointer is only defined once it has been converted back to the
/// function's actual type — which is exactly what happens here, since the
/// destination type restores the signature that the generic
/// `void *(*)(void *...)` type erased. This removes the previous round-trip
/// through an integer, which the old TODO rightly flagged as dubious.
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
  static_assert(sizeof(FnDstT) == sizeof(FnSrcT),
                "Size of function types must match");
  return reinterpret_cast<FnDstT>(src);
}
namespace {

/// Highest memref rank for which a dispatch entry is instantiated below.
constexpr size_t maxSupportedRank = 32;

/// Invokes `func` as if it returned a rank-`R` memref descriptor and converts
/// that descriptor into an owning ScalarOrTensorData.
template <size_t R>
ScalarOrTensorData callWithStaticRank(void *(*func)(void *...),
                                      std::vector<void *> &args,
                                      size_t element_width, bool is_signed) {
  using concretelang::clientlib::MemRefDescriptor;
  auto m = multi_arity_call(
      convert_fnptr<MemRefDescriptor<R> (*)(void *...)>(func), args);
  return concretelang::clientlib::tensorDataFromMemRef(
      R, element_width, is_signed, m.allocated, m.aligned, m.offset, m.sizes,
      m.strides);
}

/// Builds, at compile time, a table of callWithStaticRank instantiations for
/// ranks 1..sizeof...(Is) and dispatches on the runtime `rank`. This replaces
/// the previously generated 32-case switch (see genDynamicRandCall.py).
template <size_t... Is>
ScalarOrTensorData dispatchOnRank(std::index_sequence<Is...>,
                                  void *(*func)(void *...),
                                  std::vector<void *> &args, size_t rank,
                                  size_t element_width, bool is_signed) {
  using RankedCall = ScalarOrTensorData (*)(void *(*)(void *...),
                                            std::vector<void *> &, size_t,
                                            bool);
  // Entry i handles rank i + 1 (rank 0 is handled by the caller).
  static constexpr RankedCall table[] = {&callWithStaticRank<Is + 1>...};
  return table[rank - 1](func, args, element_width, is_signed);
}

} // namespace

/// See DynamicRankCall.h: calls `func` and decodes its result — whose rank is
/// only known at runtime — into a ScalarOrTensorData.
ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
                                                 std::vector<void *> args,
                                                 size_t rank,
                                                 size_t element_width,
                                                 bool is_signed) {
  if (rank == 0) {
    // Scalar result: the entry point returns a plain uint64_t.
    auto m =
        multi_arity_call(convert_fnptr<uint64_t (*)(void *...)>(func), args);
    return concretelang::clientlib::ScalarData(m, is_signed, element_width);
  }
  if (rank > maxSupportedRank) {
    // The old generated switch fell off the end of this value-returning
    // function after a debug-only assert; abort keeps the failure defined in
    // release builds too.
    assert(false && "multi_arity_call_dynamic_rank: unsupported rank");
    std::abort();
  }
  return dispatchOnRank(std::make_index_sequence<maxSupportedRank>{}, func,
                        args, rank, element_width, is_signed);
}
} // namespace serverlib
} // namespace concretelang

View File

@@ -9,13 +9,10 @@
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/ServerLib/DynamicRankCall.h"
#include "concretelang/ServerLib/ServerLambda.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Utils.h"
namespace concretelang {
namespace serverlib {
@@ -25,16 +22,17 @@ using concretelang::clientlib::CircuitGateShape;
using concretelang::clientlib::EvaluationKeys;
using concretelang::clientlib::PublicArguments;
using concretelang::error::StringError;
using mlir::concretelang::RuntimeContext;
using mlir::concretelang::StreamStringError;
outcome::checked<ServerLambda, StringError>
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
std::string funcName) {
auto packedFuncName = ::concretelang::makePackedFunctionName(funcName);
ServerLambda lambda;
lambda.module =
module; // prevent module and library handler from being destroyed
lambda.func =
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
lambda.func = (void (*)(void *, ...))dlsym(module->libraryHandle,
packedFuncName.c_str());
if (auto err = dlerror()) {
return StringError("Cannot open lambda:") << std::string(err);
@@ -66,29 +64,22 @@ ServerLambda::load(std::string funcName, std::string outputPath) {
return ServerLambda::loadFromModule(module, funcName);
}
std::unique_ptr<clientlib::PublicResult>
llvm::Error ServerLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
  // Reject the call up front if any argument slot was left unset.
  auto firstNull = std::find(args.begin(), args.end(), nullptr);
  if (firstNull != args.end()) {
    int pos = firstNull - args.begin();
    return StreamStringError("invoke: argument at pos ")
           << pos << " is null or missing";
  }
  assert(func != nullptr && "func pointer shouldn't be null");
  // The packed wrapper takes a single array of type-erased pointers.
  func(args.data());
  return llvm::Error::success();
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
args.preparedArgs.end());
RuntimeContext runtimeContext(evaluationKeys);
preparedArgs.push_back((void *)&runtimeContext);
assert(clientParameters.outputs.size() == 1 &&
"ServerLambda::call is implemented for only one output");
auto output = args.clientParameters.outputs[0];
auto rank = args.clientParameters.bufferShape(output).size();
size_t element_width = (output.isEncrypted()) ? 64 : output.shape.width;
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank,
element_width, sign);
std::vector<ScalarOrTensorData> results;
results.push_back(std::move(result));
return clientlib::PublicResult::fromBuffers(clientParameters,
std::move(results));
return invokeRawOnLambda(this, args.clientParameters, args.preparedArgs,
evaluationKeys);
}
} // namespace serverlib

View File

@@ -1,68 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
# Generator for DynamicRankCall.cpp: emits a dispatcher that converts the
# result of a type-erased circuit call — a scalar, or a memref descriptor
# whose rank is only known at runtime — into a ScalarOrTensorData, by
# switching over every supported rank (0..32).

# Emit the file prologue: license header, includes, the function-pointer
# reinterpretation helpers, and the rank-0 (scalar) case of the switch.
print(
    """// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// generated: see genDynamicRandCall.py

#include <cassert>
#include <vector>

#include "concretelang/ClientLib/Types.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/ServerLambda.h"

namespace concretelang {
namespace serverlib {

/// Helper class template that yields an unsigned integer type given a
/// size in bytes
template <std::size_t size> struct int_type_of_size {};
template <> struct int_type_of_size<4> { typedef uint32_t type; };
template <> struct int_type_of_size<8> { typedef uint64_t type; };

/// Converts one function pointer into another
// TODO: Not sure this is valid in all implementations / on all
// architectures
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
  static_assert(sizeof(FnDstT) == sizeof(FnSrcT),
                "Size of function types must match");

  using inttype = typename int_type_of_size<sizeof(FnDstT)>::type;
  inttype raw = reinterpret_cast<inttype>(src);
  return reinterpret_cast<FnDstT>(raw);
}

ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
                               std::vector<void *> args, size_t rank,
                               size_t element_width, bool is_signed) {
  using concretelang::clientlib::MemRefDescriptor;
  constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
  switch (rank) {
  case 0: {
    auto m =
        multi_arity_call(convert_fnptr<uint64_t (*)(void *...)>(func), args);
    return concretelang::clientlib::ScalarData(m, is_signed, element_width);
  }""")

# One switch case per supported tensor rank; each reinterprets the function
# pointer so that it returns a MemRefDescriptor of the matching static rank.
for tensor_rank in range(1, 33):
    memref_rank = tensor_rank
    print(f"""  case {tensor_rank}: {{
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);
    return convert({memref_rank}, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }}""")

# Close the switch (ranks above 32 are unsupported) and the function body.
print("""
  default:
    assert(false);
  }
}""")

# Close the namespaces.
print("""
} // namespace serverlib
} // namespace concretelang""")

View File

@@ -11,6 +11,7 @@ add_mlir_library(
logging.cpp
Jit.cpp
LLVMEmitFile.cpp
Utils.cpp
DEPENDS
mlir-headers
LINK_LIBS

View File

@@ -14,10 +14,10 @@
#include "concretelang/Common/BitsSize.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/logging.h"
#include <concretelang/Support/Utils.h>
namespace mlir {
namespace concretelang {
@@ -76,12 +76,6 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
<< pos << " is null or missing";
}
// memref is a struct which is flattened aligned, allocated pointers, offset,
// and two array of rank size for sizes and strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
return 3 + 2 * rank;
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
JITLambda::call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) {
@@ -107,85 +101,8 @@ JITLambda::call(clientlib::PublicArguments &args,
}
#endif
// invokeRaw needs to have pointers on arguments and a pointers on the result
// as last argument.
// Prepare the outputs vector to store the output value of the lambda.
auto numOutputs = 0;
for (auto &output : args.clientParameters.outputs) {
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar gate
numOutputs += 1;
} else {
// buffer gate
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
}
}
std::vector<uint64_t> outputs(numOutputs);
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
// inputs and outputs.
std::vector<void *> rawArgs(
args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
);
size_t i = 0;
// Pointers on inputs
for (auto &arg : args.preparedArgs) {
rawArgs[i++] = &arg;
}
mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
// Pointer on runtime context, the rawArgs take pointer on actual value that
// is passed to the compiled function.
auto rtCtxPtr = &runtimeContext;
rawArgs[i++] = &rtCtxPtr;
// Outputs
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
// Invoke
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
}
// Store the result to the PublicResult
std::vector<clientlib::ScalarOrTensorData> buffers;
{
size_t outputOffset = 0;
for (auto &output : args.clientParameters.outputs) {
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar scalar
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
concretelang::clientlib::ScalarData(outputs[outputOffset++],
output.shape.sign,
output.shape.width)));
} else {
// buffer gate
auto rank = shape.size();
auto allocated = (uint64_t *)outputs[outputOffset++];
auto aligned = (uint64_t *)outputs[outputOffset++];
auto offset = (size_t)outputs[outputOffset++];
size_t *sizes = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t *strides = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t elementWidth = (output.isEncrypted())
? clientlib::EncryptedScalarElementWidth
: output.shape.width;
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
concretelang::clientlib::TensorData td =
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
aligned, offset, sizes, strides);
buffers.push_back(
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
}
}
}
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
return ::concretelang::invokeRawOnLambda(this, args.clientParameters,
args.preparedArgs, evaluationKeys);
}
} // namespace concretelang

View File

@@ -5,6 +5,8 @@
#include <errno.h>
#include "llvm/MC/SubtargetFeature.h"
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/Support/ToolOutputFile.h>
@@ -15,6 +17,7 @@
#include <mlir/Support/FileUtilities.h>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Utils.h>
namespace mlir {
namespace concretelang {
@@ -22,30 +25,113 @@ namespace concretelang {
using std::string;
using std::vector;
llvm::TargetMachine *getDefaultTargetMachine() {
auto TargetTriple = llvm::sys::getDefaultTargetTriple();
string Error;
auto Target = llvm::TargetRegistry::lookupTarget(TargetTriple, Error);
if (!Target) {
// Get target machine from current machine and setup LLVM module accordingly
std::unique_ptr<llvm::TargetMachine>
getTargetMachineAndSetupModule(llvm::Module *llvmModule) {
// Setup the machine properties from the current architecture.
auto targetTriple = llvm::sys::getDefaultTargetTriple();
std::string errorMessage;
const auto *target =
llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
if (!target) {
llvm::errs() << "NO target: " << errorMessage << "\n";
return nullptr;
}
auto CPU = "generic";
auto Features = "";
llvm::TargetOptions opt;
return Target->createTargetMachine(TargetTriple, CPU, Features, opt,
llvm::Reloc::PIC_);
std::string cpu(llvm::sys::getHostCPUName());
llvm::SubtargetFeatures features;
llvm::StringMap<bool> hostFeatures;
if (llvm::sys::getHostCPUFeatures(hostFeatures))
for (auto &f : hostFeatures)
features.AddFeature(f.first(), f.second);
std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
targetTriple, cpu, features.getString(), {}, llvm::Reloc::PIC_));
if (!machine) {
llvm::errs() << "Unable to create target machine\n";
return nullptr;
}
llvmModule->setDataLayout(machine->createDataLayout());
llvmModule->setTargetTriple(targetTriple);
return machine;
}
// This function was copied from the MLIR Execution Engine, and provides an
// elegant and generic invocation interface to the compiled circuit:
// For each function in the LLVM module, define an interface function that
// wraps all the arguments of the original function and all its results into
// an i8** pointer to provide a unified invocation interface.
static void packFunctionArguments(llvm::Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  llvm::DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    // Skip declarations (no body to wrap) and wrappers emitted by previous
    // iterations of this very loop.
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)` (see makePackedFunctionName).
    auto *newType = llvm::FunctionType::get(
        builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
        /*isVarArg=*/false);
    auto newName = ::concretelang::makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc =
        llvm::cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them
    // to the proper types.
    auto *bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    llvm::SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto &indexedArg : llvm::enumerate(func.args())) {
      // Load argList[index], then load the argument value through it.
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
      llvm::Value *argPtrPtr =
          builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
      llvm::Value *argPtr =
          builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
      llvm::Type *argTy = indexedArg.value().getType();
      argPtr = builder.CreateBitCast(argPtr, argTy->getPointerTo());
      llvm::Value *arg = builder.CreateLoad(argTy, argPtr);
      args.push_back(arg);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`; a
    // non-void result is stored through the slot following the arguments.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr =
          builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
      llvm::Value *retPtr =
          builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
      retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}
llvm::Error emitObject(llvm::Module &module, string objectPath) {
auto targetMachine = getDefaultTargetMachine();
auto targetMachine = getTargetMachineAndSetupModule(&module);
if (!targetMachine) {
return StreamStringError("No default target machine for object generation");
}
module.setDataLayout(targetMachine->createDataLayout());
string Error;
std::unique_ptr<llvm::ToolOutputFile> objectFile =
mlir::openOutputFile(objectPath, &Error);
@@ -53,6 +139,8 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
return StreamStringError("Cannot create/open " + objectPath);
}
packFunctionArguments(&module);
// The legacy PassManager is mandatory for final code generation.
// https://llvm.org/docs/NewPassManager.html#status-of-the-new-and-legacy-pass-managers
llvm::legacy::PassManager pm;
@@ -67,7 +155,6 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
objectFile->os().flush();
objectFile->os().close();
objectFile->keep();
delete targetMachine;
return llvm::Error::success();
}

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/Support/Utils.h>
namespace concretelang {
/// Returns the symbol name of the packed-interface wrapper generated for the
/// circuit function `name`, i.e. the original name prefixed with "_mlir_".
std::string makePackedFunctionName(llvm::StringRef name) {
  std::string packed = "_mlir_";
  packed += name.str();
  return packed;
}
/// Number of scalar slots a rank-`rank` memref occupies once flattened for
/// the calling convention: allocated pointer, aligned pointer and offset,
/// followed by `rank` sizes and `rank` strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
  const uint64_t pointerAndOffsetFields = 3;
  return pointerAndOffsetFields + rank + rank;
}
} // namespace concretelang