mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix: add a wrapper to compiled circuits to unify invocation
this was already implemented for JIT using mlir::ExecutionEngine, but was using a different, and more complex way for library compilation and execution, which was causing a bad calling convention at the assembly level in MacOS M1 machine. This commits unify the invocation of JIT and Library compiled circuit, solving the previously mentioned issue, but also gives the ability to extend compiled libraries to support more than one returned value
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,27 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
|
||||
#define CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::ScalarOrTensorData;
|
||||
|
||||
ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args,
|
||||
size_t rank,
|
||||
size_t element_width,
|
||||
bool is_signed);
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
@@ -36,13 +37,21 @@ public:
|
||||
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
/// Call the ServerLambda with public arguments.
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
/// \brief Call the loaded function using opaque pointers to both inputs and
|
||||
/// outputs.
|
||||
/// \param args Array containing pointers to inputs first, followed by
|
||||
/// pointers to outputs.
|
||||
/// \return Error if failed, success otherwise.
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
protected:
|
||||
ClientParameters clientParameters;
|
||||
void *(*func)(void *...);
|
||||
/// holds a pointer to the entrypoint of the shared lib which
|
||||
void (*func)(void *...);
|
||||
/// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
print(
|
||||
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicArityCall.py
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace serverlib {
|
||||
|
||||
|
||||
template <typename Res>
|
||||
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
|
||||
switch (args.size()) {
|
||||
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
|
||||
""")
|
||||
|
||||
for i in range(1, 128):
|
||||
args = ','.join(f'args[{j}]' for j in range(i))
|
||||
print(f' case {i}: return func({args});')
|
||||
print("""
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}""")
|
||||
|
||||
print("""
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
""")
|
||||
@@ -41,7 +41,6 @@ public:
|
||||
|
||||
void setUseDataflow(bool option) { this->useDataflow = option; }
|
||||
|
||||
private:
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_LLVMEMITFILE
|
||||
#define CONCRETELANG_SUPPORT_LLVMEMITFILE
|
||||
|
||||
#include <llvm/ADT/StringRef.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_UTILS_H_
|
||||
#define CONCRETELANG_SUPPORT_UTILS_H_
|
||||
|
||||
#include <concretelang/ClientLib/ClientLambda.h>
|
||||
#include <concretelang/ClientLib/KeySet.h>
|
||||
#include <concretelang/ClientLib/PublicArguments.h>
|
||||
#include <concretelang/ClientLib/Serializers.h>
|
||||
#include <concretelang/Runtime/context.h>
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
// construct the function name of the wrapper function that unify function calls
|
||||
// of compiled circuit
|
||||
std::string makePackedFunctionName(llvm::StringRef name);
|
||||
|
||||
// memref is a struct which is flattened aligned, allocated pointers, offset,
|
||||
// and two array of rank size for sizes and strides.
|
||||
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank);
|
||||
|
||||
template <typename Lambda>
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
std::vector<void *> preparedInputArgs,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
auto numOutputs = 0;
|
||||
for (auto &output : clientParameters.outputs) {
|
||||
auto shape = clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar gate
|
||||
numOutputs += 1;
|
||||
} else {
|
||||
// buffer gate
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
|
||||
}
|
||||
}
|
||||
std::vector<uint64_t> outputs(numOutputs);
|
||||
|
||||
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
|
||||
// inputs and outputs.
|
||||
std::vector<void *> rawArgs(
|
||||
preparedInputArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
|
||||
);
|
||||
size_t i = 0;
|
||||
// Pointers on inputs
|
||||
for (auto &arg : preparedInputArgs) {
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
|
||||
mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
auto rtCtxPtr = &runtimeContext;
|
||||
rawArgs[i++] = &rtCtxPtr;
|
||||
|
||||
// Outputs
|
||||
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
|
||||
|
||||
// Invoke
|
||||
if (auto err = lambda->invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Store the result to the PublicResult
|
||||
std::vector<clientlib::ScalarOrTensorData> buffers;
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : clientParameters.outputs) {
|
||||
auto shape = clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar scalar
|
||||
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
|
||||
concretelang::clientlib::ScalarData(outputs[outputOffset++],
|
||||
output.shape.sign,
|
||||
output.shape.width)));
|
||||
} else {
|
||||
// buffer gate
|
||||
auto rank = shape.size();
|
||||
auto allocated = (uint64_t *)outputs[outputOffset++];
|
||||
auto aligned = (uint64_t *)outputs[outputOffset++];
|
||||
auto offset = (size_t)outputs[outputOffset++];
|
||||
size_t *sizes = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
size_t *strides = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
|
||||
size_t elementWidth = (output.isEncrypted())
|
||||
? clientlib::EncryptedScalarElementWidth
|
||||
: output.shape.width;
|
||||
|
||||
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
|
||||
concretelang::clientlib::TensorData td =
|
||||
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
|
||||
aligned, offset, sizes, strides);
|
||||
buffers.push_back(
|
||||
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -90,9 +90,12 @@ public:
|
||||
// server function call
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
|
||||
if (!publicResult) {
|
||||
return StringError("failed calling function");
|
||||
}
|
||||
|
||||
// client result decryption
|
||||
return this->decryptResult(*keySet, *publicResult);
|
||||
return this->decryptResult(*keySet, *(publicResult.get()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -7,7 +7,6 @@ endif()
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangServerLib
|
||||
DynamicRankCall.cpp
|
||||
ServerLambda.cpp
|
||||
DynamicModule.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicRandCall.py
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
/// Helper class template that yields an unsigned integer type given a
|
||||
/// size in bytes
|
||||
template <std::size_t size> struct int_type_of_size {};
|
||||
template <> struct int_type_of_size<4> { typedef uint32_t type; };
|
||||
template <> struct int_type_of_size<8> { typedef uint64_t type; };
|
||||
|
||||
/// Converts one function pointer into another
|
||||
// TODO: Not sure this is valid in all implementations / on all
|
||||
// architectures
|
||||
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
|
||||
static_assert(sizeof(FnDstT) == sizeof(FnSrcT),
|
||||
"Size of function types must match");
|
||||
using inttype = typename int_type_of_size<sizeof(FnDstT)>::type;
|
||||
inttype raw = reinterpret_cast<inttype>(src);
|
||||
return reinterpret_cast<FnDstT>(raw);
|
||||
}
|
||||
|
||||
ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args,
|
||||
size_t rank,
|
||||
size_t element_width,
|
||||
bool is_signed) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
auto m =
|
||||
multi_arity_call(convert_fnptr<uint64_t (*)(void *...)>(func), args);
|
||||
return concretelang::clientlib::ScalarData(m, is_signed, element_width);
|
||||
}
|
||||
case 1: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
|
||||
return convert(1, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 2: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
|
||||
return convert(2, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 3: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
|
||||
return convert(3, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 4: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
|
||||
return convert(4, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 5: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
|
||||
return convert(5, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 6: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
|
||||
return convert(6, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 7: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
|
||||
return convert(7, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 8: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
|
||||
return convert(8, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 9: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
|
||||
return convert(9, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 10: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
|
||||
return convert(10, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 11: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
|
||||
return convert(11, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 12: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
|
||||
return convert(12, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 13: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
|
||||
return convert(13, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 14: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
|
||||
return convert(14, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 15: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
|
||||
return convert(15, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 16: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
|
||||
return convert(16, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 17: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
|
||||
return convert(17, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 18: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
|
||||
return convert(18, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 19: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
|
||||
return convert(19, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 20: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
|
||||
return convert(20, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 21: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
|
||||
return convert(21, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 22: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
|
||||
return convert(22, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 23: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
|
||||
return convert(23, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 24: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
|
||||
return convert(24, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 25: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
|
||||
return convert(25, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 26: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
|
||||
return convert(26, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 27: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
|
||||
return convert(27, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 28: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
|
||||
return convert(28, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 29: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
|
||||
return convert(29, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 30: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
|
||||
return convert(30, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 31: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
|
||||
return convert(31, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 32: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
|
||||
return convert(32, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
@@ -9,13 +9,10 @@
|
||||
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/ServerLib/DynamicRankCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/Utils.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
@@ -25,16 +22,17 @@ using concretelang::clientlib::CircuitGateShape;
|
||||
using concretelang::clientlib::EvaluationKeys;
|
||||
using concretelang::clientlib::PublicArguments;
|
||||
using concretelang::error::StringError;
|
||||
using mlir::concretelang::RuntimeContext;
|
||||
using mlir::concretelang::StreamStringError;
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
|
||||
std::string funcName) {
|
||||
auto packedFuncName = ::concretelang::makePackedFunctionName(funcName);
|
||||
ServerLambda lambda;
|
||||
lambda.module =
|
||||
module; // prevent module and library handler from being destroyed
|
||||
lambda.func =
|
||||
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
|
||||
lambda.func = (void (*)(void *, ...))dlsym(module->libraryHandle,
|
||||
packedFuncName.c_str());
|
||||
|
||||
if (auto err = dlerror()) {
|
||||
return StringError("Cannot open lambda:") << std::string(err);
|
||||
@@ -66,29 +64,22 @@ ServerLambda::load(std::string funcName, std::string outputPath) {
|
||||
return ServerLambda::loadFromModule(module, funcName);
|
||||
}
|
||||
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
llvm::Error ServerLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
auto found = std::find(args.begin(), args.end(), nullptr);
|
||||
if (found == args.end()) {
|
||||
assert(func != nullptr && "func pointer shouldn't be null");
|
||||
func(args.data());
|
||||
return llvm::Error::success();
|
||||
}
|
||||
int pos = found - args.begin();
|
||||
return StreamStringError("invoke: argument at pos ")
|
||||
<< pos << " is null or missing";
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
|
||||
args.preparedArgs.end());
|
||||
|
||||
RuntimeContext runtimeContext(evaluationKeys);
|
||||
preparedArgs.push_back((void *)&runtimeContext);
|
||||
|
||||
assert(clientParameters.outputs.size() == 1 &&
|
||||
"ServerLambda::call is implemented for only one output");
|
||||
auto output = args.clientParameters.outputs[0];
|
||||
auto rank = args.clientParameters.bufferShape(output).size();
|
||||
|
||||
size_t element_width = (output.isEncrypted()) ? 64 : output.shape.width;
|
||||
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
|
||||
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank,
|
||||
element_width, sign);
|
||||
|
||||
std::vector<ScalarOrTensorData> results;
|
||||
results.push_back(std::move(result));
|
||||
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters,
|
||||
std::move(results));
|
||||
return invokeRawOnLambda(this, args.clientParameters, args.preparedArgs,
|
||||
evaluationKeys);
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
print(
|
||||
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicRandCall.py
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
/// Helper class template that yields an unsigned integer type given a
|
||||
/// size in bytes
|
||||
template <std::size_t size> struct int_type_of_size {};
|
||||
template <> struct int_type_of_size<4> { typedef uint32_t type; };
|
||||
template <> struct int_type_of_size<8> { typedef uint64_t type; };
|
||||
|
||||
/// Converts one function pointer into another
|
||||
// TODO: Not sure this is valid in all implementations / on all
|
||||
// architectures
|
||||
template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
|
||||
static_assert(sizeof(FnDstT) == sizeof(FnSrcT),
|
||||
"Size of function types must match");
|
||||
using inttype = typename int_type_of_size<sizeof(FnDstT)>::type;
|
||||
inttype raw = reinterpret_cast<inttype>(src);
|
||||
return reinterpret_cast<FnDstT>(raw);
|
||||
}
|
||||
|
||||
ScalarOrTensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args, size_t rank,
|
||||
size_t element_width, bool is_signed) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
auto m =
|
||||
multi_arity_call(convert_fnptr<uint64_t (*)(void *...)>(func), args);
|
||||
return concretelang::clientlib::ScalarData(m, is_signed, element_width);
|
||||
}""")
|
||||
|
||||
for tensor_rank in range(1, 33):
|
||||
memref_rank = tensor_rank
|
||||
print(f""" case {tensor_rank}: {{
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);
|
||||
return convert({memref_rank}, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}}""")
|
||||
|
||||
print("""
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}""")
|
||||
|
||||
print("""
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang""")
|
||||
@@ -11,6 +11,7 @@ add_mlir_library(
|
||||
logging.cpp
|
||||
Jit.cpp
|
||||
LLVMEmitFile.cpp
|
||||
Utils.cpp
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
#include <concretelang/Support/Utils.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -76,12 +76,6 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
<< pos << " is null or missing";
|
||||
}
|
||||
|
||||
// memref is a struct which is flattened aligned, allocated pointers, offset,
|
||||
// and two array of rank size for sizes and strides.
|
||||
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
|
||||
return 3 + 2 * rank;
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
JITLambda::call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
@@ -107,85 +101,8 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
}
|
||||
#endif
|
||||
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
auto numOutputs = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
auto shape = args.clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar gate
|
||||
numOutputs += 1;
|
||||
} else {
|
||||
// buffer gate
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
|
||||
}
|
||||
}
|
||||
std::vector<uint64_t> outputs(numOutputs);
|
||||
|
||||
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
|
||||
// inputs and outputs.
|
||||
std::vector<void *> rawArgs(
|
||||
args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
|
||||
);
|
||||
size_t i = 0;
|
||||
// Pointers on inputs
|
||||
for (auto &arg : args.preparedArgs) {
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
|
||||
mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
auto rtCtxPtr = &runtimeContext;
|
||||
rawArgs[i++] = &rtCtxPtr;
|
||||
|
||||
// Outputs
|
||||
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
|
||||
|
||||
// Invoke
|
||||
if (auto err = invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Store the result to the PublicResult
|
||||
std::vector<clientlib::ScalarOrTensorData> buffers;
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
auto shape = args.clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar scalar
|
||||
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
|
||||
concretelang::clientlib::ScalarData(outputs[outputOffset++],
|
||||
output.shape.sign,
|
||||
output.shape.width)));
|
||||
} else {
|
||||
// buffer gate
|
||||
auto rank = shape.size();
|
||||
auto allocated = (uint64_t *)outputs[outputOffset++];
|
||||
auto aligned = (uint64_t *)outputs[outputOffset++];
|
||||
auto offset = (size_t)outputs[outputOffset++];
|
||||
size_t *sizes = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
size_t *strides = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
|
||||
size_t elementWidth = (output.isEncrypted())
|
||||
? clientlib::EncryptedScalarElementWidth
|
||||
: output.shape.width;
|
||||
|
||||
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
|
||||
concretelang::clientlib::TensorData td =
|
||||
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
|
||||
aligned, offset, sizes, strides);
|
||||
buffers.push_back(
|
||||
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters,
|
||||
std::move(buffers));
|
||||
return ::concretelang::invokeRawOnLambda(this, args.clientParameters,
|
||||
args.preparedArgs, evaluationKeys);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include <errno.h>
|
||||
|
||||
#include "llvm/MC/SubtargetFeature.h"
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/MC/TargetRegistry.h>
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
@@ -15,6 +17,7 @@
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Utils.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -22,30 +25,113 @@ namespace concretelang {
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
llvm::TargetMachine *getDefaultTargetMachine() {
|
||||
auto TargetTriple = llvm::sys::getDefaultTargetTriple();
|
||||
string Error;
|
||||
|
||||
auto Target = llvm::TargetRegistry::lookupTarget(TargetTriple, Error);
|
||||
if (!Target) {
|
||||
// Get target machine from current machine and setup LLVM module accordingly
|
||||
std::unique_ptr<llvm::TargetMachine>
|
||||
getTargetMachineAndSetupModule(llvm::Module *llvmModule) {
|
||||
// Setup the machine properties from the current architecture.
|
||||
auto targetTriple = llvm::sys::getDefaultTargetTriple();
|
||||
std::string errorMessage;
|
||||
const auto *target =
|
||||
llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
|
||||
if (!target) {
|
||||
llvm::errs() << "NO target: " << errorMessage << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto CPU = "generic";
|
||||
auto Features = "";
|
||||
llvm::TargetOptions opt;
|
||||
return Target->createTargetMachine(TargetTriple, CPU, Features, opt,
|
||||
llvm::Reloc::PIC_);
|
||||
std::string cpu(llvm::sys::getHostCPUName());
|
||||
llvm::SubtargetFeatures features;
|
||||
llvm::StringMap<bool> hostFeatures;
|
||||
|
||||
if (llvm::sys::getHostCPUFeatures(hostFeatures))
|
||||
for (auto &f : hostFeatures)
|
||||
features.AddFeature(f.first(), f.second);
|
||||
|
||||
std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
|
||||
targetTriple, cpu, features.getString(), {}, llvm::Reloc::PIC_));
|
||||
if (!machine) {
|
||||
llvm::errs() << "Unable to create target machine\n";
|
||||
return nullptr;
|
||||
}
|
||||
llvmModule->setDataLayout(machine->createDataLayout());
|
||||
llvmModule->setTargetTriple(targetTriple);
|
||||
return machine;
|
||||
}
|
||||
|
||||
// This function was copied from the MLIR Execution Engine, and provide an
|
||||
// elegant and generic invocation interface to the compiled circuit:
|
||||
// For each function in the LLVM module, define an interface function that wraps
|
||||
// all the arguments of the original function and all its results into an i8**
|
||||
// pointer to provide a unified invocation interface.
|
||||
static void packFunctionArguments(llvm::Module *module) {
|
||||
auto &ctx = module->getContext();
|
||||
llvm::IRBuilder<> builder(ctx);
|
||||
llvm::DenseSet<llvm::Function *> interfaceFunctions;
|
||||
for (auto &func : module->getFunctionList()) {
|
||||
if (func.isDeclaration()) {
|
||||
continue;
|
||||
}
|
||||
if (interfaceFunctions.count(&func)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Given a function `foo(<...>)`, define the interface function
|
||||
// `mlir_foo(i8**)`.
|
||||
auto *newType = llvm::FunctionType::get(
|
||||
builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
|
||||
/*isVarArg=*/false);
|
||||
auto newName = ::concretelang::makePackedFunctionName(func.getName());
|
||||
auto funcCst = module->getOrInsertFunction(newName, newType);
|
||||
llvm::Function *interfaceFunc =
|
||||
llvm::cast<llvm::Function>(funcCst.getCallee());
|
||||
interfaceFunctions.insert(interfaceFunc);
|
||||
|
||||
// Extract the arguments from the type-erased argument list and cast them to
|
||||
// the proper types.
|
||||
auto *bb = llvm::BasicBlock::Create(ctx);
|
||||
bb->insertInto(interfaceFunc);
|
||||
builder.SetInsertPoint(bb);
|
||||
llvm::Value *argList = interfaceFunc->arg_begin();
|
||||
llvm::SmallVector<llvm::Value *, 8> args;
|
||||
args.reserve(llvm::size(func.args()));
|
||||
for (auto &indexedArg : llvm::enumerate(func.args())) {
|
||||
llvm::Value *argIndex = llvm::Constant::getIntegerValue(
|
||||
builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
|
||||
llvm::Value *argPtrPtr =
|
||||
builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
|
||||
llvm::Value *argPtr =
|
||||
builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
|
||||
llvm::Type *argTy = indexedArg.value().getType();
|
||||
argPtr = builder.CreateBitCast(argPtr, argTy->getPointerTo());
|
||||
llvm::Value *arg = builder.CreateLoad(argTy, argPtr);
|
||||
args.push_back(arg);
|
||||
}
|
||||
|
||||
// Call the implementation function with the extracted arguments.
|
||||
llvm::Value *result = builder.CreateCall(&func, args);
|
||||
|
||||
// Assuming the result is one value, potentially of type `void`.
|
||||
if (!result->getType()->isVoidTy()) {
|
||||
llvm::Value *retIndex = llvm::Constant::getIntegerValue(
|
||||
builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
|
||||
llvm::Value *retPtrPtr =
|
||||
builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
|
||||
llvm::Value *retPtr =
|
||||
builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
|
||||
retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
|
||||
builder.CreateStore(result, retPtr);
|
||||
}
|
||||
|
||||
// The interface function returns void.
|
||||
builder.CreateRetVoid();
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Error emitObject(llvm::Module &module, string objectPath) {
|
||||
auto targetMachine = getDefaultTargetMachine();
|
||||
auto targetMachine = getTargetMachineAndSetupModule(&module);
|
||||
if (!targetMachine) {
|
||||
return StreamStringError("No default target machine for object generation");
|
||||
}
|
||||
|
||||
module.setDataLayout(targetMachine->createDataLayout());
|
||||
|
||||
string Error;
|
||||
std::unique_ptr<llvm::ToolOutputFile> objectFile =
|
||||
mlir::openOutputFile(objectPath, &Error);
|
||||
@@ -53,6 +139,8 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
|
||||
return StreamStringError("Cannot create/open " + objectPath);
|
||||
}
|
||||
|
||||
packFunctionArguments(&module);
|
||||
|
||||
// The legacy PassManager is mandatory for final code generation.
|
||||
// https://llvm.org/docs/NewPassManager.html#status-of-the-new-and-legacy-pass-managers
|
||||
llvm::legacy::PassManager pm;
|
||||
@@ -67,7 +155,6 @@ llvm::Error emitObject(llvm::Module &module, string objectPath) {
|
||||
objectFile->os().flush();
|
||||
objectFile->os().close();
|
||||
objectFile->keep();
|
||||
delete targetMachine;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
|
||||
18
compilers/concrete-compiler/compiler/lib/Support/Utils.cpp
Normal file
18
compilers/concrete-compiler/compiler/lib/Support/Utils.cpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Support/Utils.h>
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
std::string makePackedFunctionName(llvm::StringRef name) {
|
||||
return "_mlir_" + name.str();
|
||||
}
|
||||
|
||||
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
|
||||
return 3 + 2 * rank;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
Reference in New Issue
Block a user