refactor(jit): Use PublicArguments instead of JitLambda::Argument to call the lambda (uniform calling convention for ServerLambda and JitLambda)

This commit is contained in:
Quentin Bourgerie
2022-03-02 14:50:02 +01:00
parent 8a52cdaaf5
commit 0d1f041323
20 changed files with 485 additions and 325 deletions

View File

@@ -82,19 +82,18 @@ public:
/// ServerLambda::real_call_write function. ostream must be in binary mode
/// std::ios_base::openmode::binary
outcome::checked<void, StringError>
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
std::ostream &ostream) {
serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) {
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
return publicArguments->serialize(ostream);
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
publicArguments(Args... args, KeySet &keySet) {
OUTCOME_TRY(auto clientArguments,
EncryptedArguments::create(keySet, args...));
return clientArguments->exportPublicArguments(clientParameters,
keySet->runtimeContext());
keySet.runtimeContext());
}
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,

View File

@@ -120,6 +120,8 @@ static inline bool operator==(const CircuitGateShape &lhs,
/// Describes one input or output gate of a compiled circuit: its optional
/// encryption parameters and the shape of the value it carries.
struct CircuitGate {
  /// Encryption parameters; llvm::None when the gate carries clear values.
  llvm::Optional<EncryptionGate> encryption;
  /// Shape of the value flowing through the gate.
  CircuitGateShape shape;
  /// Returns true iff the gate carries encrypted values.
  /// const-qualified so it can be queried through const references.
  bool isEncrypted() const { return encryption.hasValue(); }
};
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
@@ -140,7 +142,32 @@ struct ClientParameters {
static std::string getClientParametersPath(std::string path);
LweSecretKeyParam lweSecretKeyParam(CircuitGate gate);
/// Returns the input gate at position `pos`.
///
/// Fails with a StringError when the circuit has fewer than `pos + 1`
/// inputs. (Error message grammar fixed: was "didn't exists".)
outcome::checked<CircuitGate, StringError> input(size_t pos) {
  if (pos >= inputs.size()) {
    return StringError("input gate ") << pos << " does not exist";
  }
  return inputs[pos];
}
/// Returns the output gate at position `pos`.
///
/// Fails with a StringError when the circuit has fewer than `pos + 1`
/// outputs. (Error message grammar fixed: was "didn't exists".)
/// NOTE(review): the method name `ouput` (sic) is part of the public
/// interface used by callers, so the typo is kept; renaming would be a
/// breaking change.
outcome::checked<CircuitGate, StringError> ouput(size_t pos) {
  if (pos >= outputs.size()) {
    return StringError("output gate ") << pos << " does not exist";
  }
  return outputs[pos];
}
/// Looks up the LWE secret key parameters associated with `gate`.
///
/// Fails when the gate is not encrypted, or when the gate's secret key id
/// is unknown to these client parameters.
outcome::checked<LweSecretKeyParam, StringError>
lweSecretKeyParam(CircuitGate gate) {
  if (!gate.encryption.hasValue()) {
    return StringError("gate is not encrypted");
  }
  const auto &keyID = gate.encryption->secretKeyID;
  auto it = secretKeys.find(keyID);
  if (it != secretKeys.end()) {
    return it->second;
  }
  return StringError("cannot find ") << keyID << " in client parameters";
}
};
static inline bool operator==(const ClientParameters &lhs,

View File

@@ -35,12 +35,16 @@ public:
/// an EncryptedArguments
template <typename... Args>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(std::shared_ptr<KeySet> keySet, Args... args) {
create(KeySet &keySet, Args... args) {
auto arguments = std::make_unique<EncryptedArguments>();
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
return arguments;
}
/// Creates an EncryptedArguments holding no arguments yet; arguments are
/// expected to be pushed incrementally afterwards (see pushArg/pushArgs).
static std::unique_ptr<EncryptedArguments> empty() {
return std::make_unique<EncryptedArguments>();
}
/// Export encrypted arguments as public arguments, reset the encrypted
/// arguments, i.e. move all buffers to the PublicArguments and reset the
/// positional counter.
@@ -48,31 +52,44 @@ public:
exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
public:
/// Add a uint8_t scalar argument.
outcome::checked<void, StringError> pushArg(uint8_t arg,
std::shared_ptr<KeySet> keySet);
/// Check that all arguments as been pushed.
/// TODO: Remove public method here
outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
public:
// Add a uint64_t scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg,
std::shared_ptr<KeySet> keySet);
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet);
/// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet);
KeySet &keySet);
// Add a 1D tensor argument with data and size of the dimension.
// NOTE(review): the elements are copied through a std::vector<uint8_t>,
// so values of a wider T are narrowed to 8 bits here — confirm this is
// intended for every instantiation of T.
template <typename T>
outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
KeySet &keySet) {
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
}
// Add a tensor argument of element type T with an explicit shape.
// The element width is derived from T (8 * sizeof(T) bits) and the raw
// buffer is forwarded to the width/void*/shape overload of pushArg.
template <typename T>
outcome::checked<void, StringError>
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
keySet);
}
/// Add a 1D tensor argument.
template <size_t size>
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
std::shared_ptr<KeySet> keySet) {
KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size}, keySet);
}
/// Add a 2D tensor argument.
template <size_t size0, size_t size1>
outcome::checked<void, StringError>
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
std::shared_ptr<KeySet> keySet) {
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
}
@@ -80,7 +97,7 @@ public:
template <size_t size0, size_t size1, size_t size2>
outcome::checked<void, StringError>
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
std::shared_ptr<KeySet> keySet) {
KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
}
@@ -88,41 +105,48 @@ public:
// Set a argument at the given pos as a 1D tensor of T.
template <typename T>
outcome::checked<void, StringError> pushArg(T *data, size_t dim1,
std::shared_ptr<KeySet> keySet) {
return pushArg<T>(data, llvm::ArrayRef<size_t>(&dim1, 1), keySet);
outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
KeySet &keySet) {
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
}
// Set a argument at the given pos as a tensor of T.
template <typename T>
outcome::checked<void, StringError> pushArg(T *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape, keySet);
outcome::checked<void, StringError>
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
keySet);
}
outcome::checked<void, StringError> pushArg(size_t width, void *data,
outcome::checked<void, StringError> pushArg(size_t width, const void *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet);
KeySet &keySet);
/// Push a variadic list of arguments.
// Recursive case for scalars: extract first scalar argument from
// parameter pack and forward rest
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
Arg0 arg0, OtherArgs... others) {
outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
OtherArgs... others) {
OUTCOME_TRYV(pushArg(arg0, keySet));
return pushArgs(keySet, others...);
}
// Recursive case for tensors: extract pointer and size from
// parameter pack and forward rest.
// Pushes (arg0, size) as one tensor argument via pushArg, then recurses
// on the remaining pack; OUTCOME_TRYV propagates any StringError.
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError>
pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
OUTCOME_TRYV(pushArg(arg0, size, keySet));
return pushArgs(keySet, others...);
}
// Terminal case of pushArgs
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
return checkAllArgs(keySet);
}
private:
outcome::checked<void, StringError>
checkPushTooManyArgs(std::shared_ptr<KeySet> keySet);
outcome::checked<void, StringError>
checkAllArgs(std::shared_ptr<KeySet> keySet);
outcome::checked<void, StringError> checkPushTooManyArgs(KeySet &keySet);
private:
// Position of the next pushed argument

View File

@@ -21,6 +21,11 @@ namespace serverlib {
class ServerLambda;
}
} // namespace concretelang
namespace mlir {
namespace concretelang {
class JITLambda;
}
} // namespace mlir
namespace concretelang {
namespace clientlib {
@@ -45,7 +50,8 @@ public:
outcome::checked<void, StringError> serialize(std::ostream &ostream);
private:
friend class ::concretelang::serverlib::ServerLambda; // from ServerLib
friend class ::concretelang::serverlib::ServerLambda;
friend class ::mlir::concretelang::JITLambda;
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
@@ -82,16 +88,27 @@ struct PublicResult {
/// Serialize into an output stream.
outcome::checked<void, StringError> serialize(std::ostream &ostream);
/// Decrypt the result at `pos` as a vector.
/// Get the result at `pos` as a vector, if the result is a scalar returns a
/// vector of size 1. Decryption happens if the result is encrypted.
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
decryptVector(KeySet &keySet, size_t pos);
asClearTextVector(KeySet &keySet, size_t pos);
private:
// private: TODO tmp
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;
std::vector<TensorData> buffers;
};
/// Helper function to convert from a scalar to TensorData
TensorData tensorDataFromScalar(uint64_t value);
/// Helper function to convert from MemRefDescriptor to
/// TensorData
TensorData tensorDataFromMemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides);
} // namespace clientlib
} // namespace concretelang

View File

@@ -56,7 +56,7 @@ std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &ostream);
std::ostream &serializeTensorData(TensorData &values_and_sizes,

View File

@@ -35,7 +35,7 @@ using encrypted_scalars_t = uint64_t *;
struct TensorData {
std::vector<uint64_t> values; // tensor of rank r + 1
std::vector<size_t> sizes; // r sizes
std::vector<int64_t> sizes; // r sizes
inline size_t length() {
if (sizes.empty()) {

View File

@@ -23,10 +23,6 @@ using concretelang::clientlib::encrypted_scalar_t;
using concretelang::clientlib::encrypted_scalars_t;
using concretelang::clientlib::TensorData;
TensorData TensorData_from_MemRef(size_t rank, encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides);
/// ServerLambda is a utility class that allows to call a function of a
/// compilation result.
class ServerLambda {

View File

@@ -11,12 +11,14 @@
#include <mlir/Support/LogicalResult.h>
#include <concretelang/ClientLib/KeySet.h>
#include <concretelang/ClientLib/PublicArguments.h>
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::KeySet;
namespace clientlib = ::concretelang::clientlib;
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
@@ -118,6 +120,11 @@ public:
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
/// Call the JIT lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
call(clientlib::PublicArguments &args);
private:
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
/// used to store the result of the computation.
/// Example:
@@ -127,9 +134,6 @@ public:
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
/// invoke the jit lambda with the Argument.
llvm::Error invoke(Argument &args);
private:
mlir::LLVM::LLVMFunctionType type;
std::string name;

View File

@@ -17,6 +17,7 @@ namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::KeySetCache;
namespace clientlib = ::concretelang::clientlib;
namespace {
// Generic function template as well as specializations of
@@ -26,34 +27,35 @@ namespace {
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result);
// Specialization of `typedResult()` for scalar results, forwarding
// scalar value to caller
template <>
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
uint64_t res = 0;
if (auto err = arguments.getResult(0, res))
return StreamStringError() << "Cannot retrieve result:" << err;
return res;
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedResult cannot get clear text vector")
<< clearResult.error().mesg;
}
if (clearResult.value().size() != 1) {
return StreamStringError("typedResult expect only one value but got ")
<< clearResult.value().size();
}
return clearResult.value()[0];
}
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(JITLambda::Argument &arguments) {
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
if (auto err = n.takeError())
return std::move(err);
std::vector<T> res(*n);
if (auto err = arguments.getResult(0, res.data(), res.size()))
return StreamStringError() << "Cannot retrieve result:" << err;
return std::move(res);
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedVectorResult cannot get clear text vector")
<< clearResult.error().mesg;
}
return std::move(clearResult.value());
}
// Specializations of `typedResult()` for vector results, initializing
@@ -62,151 +64,144 @@ typedVectorResult(JITLambda::Argument &arguments) {
//
// Cannot factor out into a template template <typename T> inline
// llvm::Expected<std::vector<uint8_t>>
// typedResult(JITLambda::Argument &arguments); due to ambiguity with
// scalar template
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(JITLambda::Argument &arguments) {
return typedVectorResult<uint8_t>(arguments);
}
template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(JITLambda::Argument &arguments) {
return typedVectorResult<uint16_t>(arguments);
}
template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(JITLambda::Argument &arguments) {
return typedVectorResult<uint32_t>(arguments);
}
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
// to ambiguity with scalar template
// template <>
// inline llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint8_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint16_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint16_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint32_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint32_t>(keySet, result);
// }
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
return typedVectorResult<uint64_t>(arguments);
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint64_t>(keySet, result);
}
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(JITLambda::Argument &arguments) {
buildTensorLambdaResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
llvm::Expected<std::vector<T>> tensorOrError =
typedResult<std::vector<T>>(arguments);
typedResult<std::vector<T>>(keySet, result);
if (!tensorOrError)
return std::move(tensorOrError.takeError());
if (auto err = tensorOrError.takeError())
return std::move(err);
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
result.buffers[0].sizes.end() - 1);
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
arguments.getResultDimensions(0);
if (!tensorDimOrError)
return tensorDimOrError.takeError();
return std::move(std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, *tensorDimOrError));
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, tensorDim);
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(JITLambda::Argument &arguments) {
llvm::Expected<enum JITLambda::Argument::ResultType> resTy =
arguments.getResultType(0);
if (!resTy)
return resTy.takeError();
switch (*resTy) {
case JITLambda::Argument::ResultType::SCALAR: {
uint64_t res;
if (llvm::Error err = arguments.getResult(0, res))
return std::move(err);
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto gate = keySet.outputGate(0);
// scalar case
if (gate.shape.dimensions.empty()) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (clearResult.has_error()) {
return StreamStringError("typedResult: ") << clearResult.error().mesg;
}
auto res = clearResult.value()[0];
return std::make_unique<IntLambdaArgument<uint64_t>>(res);
}
// tensor case
// auto width = gate.shape.width;
case JITLambda::Argument::ResultType::TENSOR: {
llvm::Expected<size_t> width = arguments.getResultWidth(0);
// if (width > 32)
return buildTensorLambdaResult<uint64_t>(keySet, result);
// else if (width > 16)
// return buildTensorLambdaResult<uint32_t>(keySet, result);
// else if (width > 8)
// return buildTensorLambdaResult<uint16_t>(keySet, result);
// else if (width <= 8)
// return buildTensorLambdaResult<uint8_t>(keySet, result);
if (!width)
return width.takeError();
// return StreamStringError("Cannot handle scalars with more than 64 bits");
}
if (*width > 64)
return StreamStringError("Cannot handle scalars with more than 64 bits");
if (*width > 32)
return buildTensorLambdaResult<uint64_t>(arguments);
else if (*width > 16)
return buildTensorLambdaResult<uint32_t>(arguments);
else if (*width > 8)
return buildTensorLambdaResult<uint16_t>(arguments);
else if (*width <= 8)
return buildTensorLambdaResult<uint8_t>(arguments);
}
}
return StreamStringError("Unknown result type");
} // namespace
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to `JitLambda::Argument`.
// Adaptor class that push arguments specified as instances of
// `LambdaArgument` to `clientlib::EncryptedArguments`.
class JITLambdaArgumentAdaptor {
public:
// Checks if the argument `arg` is an plaintext / encrypted integer
// argument or a plaintext / encrypted tensor argument with a
// backing integer type `IntT` and adds the argument to `jla` at
// position `pos`.
// backing integer type `IntT` and push the argument to `encryptedArgs`.
//
// Returns `true` if `arg` has one of the types above and its value
// was successfully added to `jla`, `false` if none of the types
// was successfully added to `encryptedArgs`, `false` if none of the types
// matches or an error if a type matched, but adding the argument to
// `jla` failed.
// `encryptedArgs` failed.
template <typename IntT>
static inline llvm::Expected<bool>
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
if (llvm::Error err = jla.setArg(pos, ila->getValue()))
return std::move(err);
else
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
} else if (auto tla = arg.dyn_cast<
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
if (llvm::Error err =
jla.setArg(pos, tla->getValue(), tla->getDimensions()))
return std::move(err);
else
auto res =
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
}
return false;
}
// Recursive case for `tryAddArg<IntT>(...)`
template <typename IntT, typename NextIntT, typename... IntTs>
static inline llvm::Expected<bool>
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
llvm::Expected<bool> successOrError =
tryAddArg<IntT>(encryptedArgs, arg, keySet);
if (!successOrError)
return successOrError.takeError();
if (successOrError.get() == false)
return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
else
return true;
}
// Attempts to add a single argument `arg` to `jla` at position
// `pos`. Returns an error if either the argument type is
// unsupported or if the argument types is supported, but adding it
// to `jla` failed.
static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
const LambdaArgument &arg) {
// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
// error if either the argument type is unsupported or if the argument types
// is supported, but adding it to `encryptedArgs` failed.
static inline llvm::Error
addArgument(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
// Try the supported integer types; size_t needs explicit
// treatment, since it may alias none of the fixed size integer
// types
llvm::Expected<bool> successOrError =
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
uint8_t, size_t>(jla, pos, arg);
uint8_t, size_t>(encryptedArgs, arg,
keySet);
if (!successOrError)
return successOrError.takeError();
@@ -217,7 +212,6 @@ public:
return llvm::Error::success();
}
};
} // namespace
// A compiler engine that JIT-compiles a source and produces a lambda
// object directly invocable through its call operator.
@@ -231,12 +225,15 @@ public:
Lambda(Lambda &&other)
: innerLambda(std::move(other.innerLambda)),
keySet(std::move(other.keySet)),
compilationContext(other.compilationContext) {}
compilationContext(other.compilationContext),
clientParameters(other.clientParameters) {}
Lambda(std::shared_ptr<CompilationContext> compilationContext,
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet,
clientlib::ClientParameters clientParameters)
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
compilationContext(compilationContext) {}
compilationContext(compilationContext),
clientParameters(clientParameters) {}
// Returns the number of arguments required for an invocation of
// the lambda
@@ -251,81 +248,96 @@ public:
template <typename ResT = uint64_t>
llvm::Expected<ResT>
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
// Encrypt the arguments
auto encryptedArgs = clientlib::EncryptedArguments::empty();
for (size_t i = 0; i < lambdaArgs.size(); i++) {
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
*arguments, i, *lambdaArgs[i])) {
*encryptedArgs, *lambdaArgs[i], *this->keySet)) {
return std::move(err);
}
}
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
auto check = encryptedArgs->checkAllArgs(*this->keySet);
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
return std::move(typedResult<ResT>(*arguments));
// Export as public arguments
auto publicArguments = encryptedArgs->exportPublicArguments(
clientParameters, keySet->runtimeContext());
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
// Call the lambda
auto publicResult = this->innerLambda->call(*publicArguments.value());
if (auto err = publicResult.takeError()) {
return std::move(err);
}
return typedResult<ResT>(*keySet, **publicResult);
}
// Invocation with an array of arguments of the same type
template <typename T, typename ResT = uint64_t>
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
// Encrypt the arguments
auto encryptedArgs = clientlib::EncryptedArguments::empty();
for (size_t i = 0; i < args.size(); i++) {
if (auto err = arguments->setArg(i, args[i])) {
return StreamStringError()
<< "Cannot push argument " << i << ": " << err;
auto res = encryptedArgs->pushArg(args[i], *keySet);
if (res.has_error()) {
return StreamStringError(res.error().mesg);
}
}
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
auto check = encryptedArgs->checkAllArgs(*this->keySet);
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
return std::move(typedResult<ResT>(*arguments));
// Export as public arguments
auto publicArguments = encryptedArgs->exportPublicArguments(
clientParameters, keySet->runtimeContext());
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
// Call the lambda
auto publicResult = this->innerLambda->call(*publicArguments.value());
if (auto err = publicResult.takeError()) {
return std::move(err);
}
return typedResult<ResT>(*keySet, **publicResult);
}
// Invocation with arguments of different types
template <typename ResT = uint64_t, typename... Ts>
llvm::Expected<ResT> operator()(const Ts... ts) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
// Encrypt the arguments
auto encryptedArgs =
clientlib::EncryptedArguments::create(*keySet, ts...);
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
// Export as public arguments
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
// Call the lambda
auto publicResult = this->innerLambda->call(*publicArguments.value());
if (auto err = publicResult.takeError()) {
return std::move(err);
}
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
return std::move(typedResult<ResT>(*arguments));
return typedResult<ResT>(*keySet, **publicResult);
}
protected:
@@ -364,6 +376,7 @@ public:
std::unique_ptr<JITLambda> innerLambda;
std::unique_ptr<KeySet> keySet;
std::shared_ptr<CompilationContext> compilationContext;
const clientlib::ClientParameters clientParameters;
};
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =

View File

@@ -63,12 +63,29 @@ public:
keySet(keySet) {}
outcome::checked<Result, StringError> call(Args... args) {
// std::string message;
// client stream
// std::ostringstream clientOuput(std::ios::binary);
// client argument encryption
OUTCOME_TRY(auto encryptedArgs,
clientlib::EncryptedArguments::create(keySet, args...));
clientlib::EncryptedArguments::create(*keySet, args...));
OUTCOME_TRY(auto publicArgument,
encryptedArgs->exportPublicArguments(this->clientParameters,
keySet->runtimeContext()));
// client argument serialization
// publicArgument->serialize(clientOuput);
// message = clientOuput.str();
// server stream
// std::istringstream serverInput(message, std::ios::binary);
// freeStringMemory(message);
//
// OUTCOME_TRY(auto publicArguments,
// clientlib::PublicArguments::unserialize(
// this->clientParameters,
// serverInput));
// server function call
auto publicResult = serverLambda.call(*publicArgument);

View File

@@ -24,11 +24,9 @@ ClientLambda::load(std::string functionName, std::string jsonPath) {
return StringError("ClientLambda: cannot find function ")
<< functionName << " in client parameters" << jsonPath;
}
if (param->outputs.size() != 1) {
return StringError("ClientLambda: output arity (")
<< std::to_string(param->outputs.size())
<< ") != 1 is not supported";
<< std::to_string(param->outputs.size()) << ") != 1 is not supprted";
}
if (!param->outputs[0].encryption.hasValue()) {
@@ -54,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
return result.decryptVector(keySet, 0);
return result.asClearTextVector(keySet, 0);
}
outcome::checked<void, StringError> errorResultRank(size_t expected,

View File

@@ -53,10 +53,6 @@ std::size_t ClientParameters::hash() {
return currentHash;
}
LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) {
return secretKeys.find(gate.encryption->secretKeyID)->second;
}
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
llvm::json::Object object{
{"dimension", v.dimension},

View File

@@ -22,34 +22,25 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
return pushArg((uint64_t)arg, keySet);
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos++;
CircuitGate input = keySet->inputGate(pos);
CircuitGate input = keySet.inputGate(pos);
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}
if (!input.encryption.hasValue()) {
// clear scalar: just push the argument
if (input.shape.width != 64) {
return StringError(
"scalar argument of with != 64 is not supported for DynamicLambda");
}
preparedArgs.push_back((void *)arg);
return outcome::success();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
TensorData &values_and_sizes = ciphertextBuffers.back();
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
values_and_sizes.values.resize(lweSize);
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values_and_sizes.values.data(), arg));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
@@ -66,18 +57,16 @@ EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(std::vector<uint8_t> arg, KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(size_t width, void *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(size_t width, const void *data,
llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet->inputGate(pos);
CircuitGate input = keySet.inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StringError("argument #")
@@ -108,7 +97,7 @@ EncryptedArguments::pushArg(size_t width, void *data,
}
}
if (input.encryption.hasValue()) {
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
@@ -124,9 +113,14 @@ EncryptedArguments::pushArg(size_t width, void *data,
// Allocate ciphertexts and encrypt, for every values in tensor
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values.data() + offset, data8[i]));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i]));
}
} // TODO: NON ENCRYPTED, COPY CONTENT TO values_and_sizes
} else {
values_and_sizes.values.resize(input.shape.size);
for (size_t i = 0; i < input.shape.size; i++) {
values_and_sizes.values[i] = ((const uint64_t *)data)[i];
}
}
// allocated
preparedArgs.push_back(nullptr);
// aligned
@@ -150,8 +144,8 @@ EncryptedArguments::pushArg(size_t width, void *data,
}
outcome::checked<void, StringError>
EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
EncryptedArguments::checkPushTooManyArgs(KeySet &keySet) {
size_t arity = keySet.numInputs();
if (currentPos < arity) {
return outcome::success();
}
@@ -160,8 +154,8 @@ EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
}
outcome::checked<void, StringError>
EncryptedArguments::checkAllArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
EncryptedArguments::checkAllArgs(KeySet &keySet) {
size_t arity = keySet.numInputs();
if (currentPos == arity) {
return outcome::success();
}

View File

@@ -19,10 +19,11 @@ namespace clientlib {
using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
bool clearRuntimeContext, std::vector<void *> &&preparedArgs_,
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers_)
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
RuntimeContext runtimeContext,
bool clearRuntimeContext,
std::vector<void *> &&preparedArgs_,
std::vector<TensorData> &&ciphertextBuffers_)
: clientParameters(clientParameters), runtimeContext(runtimeContext),
clearRuntimeContext(clearRuntimeContext) {
preparedArgs = std::move(preparedArgs_);
@@ -63,7 +64,7 @@ PublicArguments::serialize(std::ostream &ostream) {
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
assert(aligned != nullptr);
auto offset = (size_t)preparedArgs[iPreparedArgs++];
std::vector<size_t> sizes; // includes lweSize as last dim
std::vector<int64_t> sizes; // includes lweSize as last dim
sizes.resize(rank + 1);
for (auto dim = 0u; dim < sizes.size(); dim++) {
// sizes are part of the client parameters signature
@@ -91,7 +92,7 @@ PublicArguments::unserializeArgs(std::istream &istream) {
if (!gate.encryption.hasValue()) {
return StringError("Clear values are not handled");
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize();
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
@@ -135,14 +136,17 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
return sArguments;
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
PublicResult::decryptVector(KeySet &keySet, size_t pos) {
auto lweSize =
clientParameters.lweSecretKeyParam(clientParameters.outputs[pos])
.lweSize();
outcome::checked<std::vector<uint64_t>, StringError>
PublicResult::asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted()) {
return buffers[pos].values;
}
auto buffer = buffers[pos];
decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize);
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<uint64_t> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
@@ -150,5 +154,63 @@ PublicResult::decryptVector(KeySet &keySet, size_t pos) {
return decryptedValues;
}
// Advance a row-major multi-dimensional coordinate by one position,
// odometer-style: the innermost (rightmost) dimension varies fastest and
// each exhausted dimension wraps to zero, carrying into the next one out.
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
  int dim = static_cast<int>(rank) - 1;
  while (dim >= 0) {
    if (index[dim] < sizes[dim] - 1) {
      index[dim]++;
      return;
    }
    index[dim] = 0;
    --dim;
  }
}
// Flatten a multi-dimensional coordinate into a linear offset. A stride of 0
// means "use the default row-major stride" computed from the trailing sizes.
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
                    size_t rank) {
  size_t linear = 0;
  size_t rowMajorStride = 1;
  int r = static_cast<int>(rank);
  while (r-- > 0) {
    const size_t stride = (strides[r] == 0) ? rowMajorStride : strides[r];
    linear += index[r] * stride;
    rowMajorStride *= sizes[r];
  }
  return linear;
}
TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; }
/// Copies a bufferized MemRef descriptor (allocated/aligned pointers, offset,
/// sizes and strides arrays of length `memref_rank`) into an owned
/// TensorData, honoring strides, then releases the MemRef's allocation.
TensorData tensorDataFromMemRef(size_t memref_rank,
                                encrypted_scalars_t allocated,
                                encrypted_scalars_t aligned, size_t offset,
                                size_t *sizes, size_t *strides) {
  TensorData result;
  assert(aligned != nullptr);
  // Copy the memref shape into the result.
  result.sizes.resize(memref_rank);
  for (size_t r = 0; r < memref_rank; r++) {
    result.sizes[r] = sizes[r];
  }
  // Ephemeral multi-dim index used to walk the (possibly strided) memref;
  // std::vector instead of a raw new[]/delete[] pair avoids a leak if
  // anything below throws.
  std::vector<size_t> index(memref_rank, 0);
  auto len = result.length();
  result.values.resize(len);
  // TODO: add a fast path for dense result (no real strides)
  for (size_t i = 0; i < len; i++) {
    // A memref element lives at aligned + offset + sum(index[r] * stride[r]).
    // The previous code folded `offset` into g_index and then added it a
    // second time when indexing `aligned`, mis-reading any memref with a
    // non-zero offset.
    size_t g_index = global_index(index.data(), sizes, strides, memref_rank);
    result.values[i] = aligned[offset + g_index];
    next_coord_index(index.data(), sizes, memref_rank);
  }
  // TEMPORARY: Quick and dirty, but as this function is only used to convert
  // a result of the MLIR program, and the data are copied here, we release
  // the allocated pointer if it is set.
  if (allocated != nullptr) {
    free(allocated);
  }
  return result;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -93,7 +93,7 @@ std::ostream &serializeTensorData(uint64_t *values, size_t length,
return ostream;
}
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &ostream) {
size_t length = 1;
for (auto size : sizes) {
@@ -107,7 +107,7 @@ std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream) {
std::vector<size_t> &sizes = values_and_sizes.sizes;
std::vector<int64_t> &sizes = values_and_sizes.sizes;
encrypted_scalars_t values = values_and_sizes.values.data();
return serializeTensorData(sizes, values, ostream);
}

View File

@@ -19,7 +19,7 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args,
size_t rank) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = &TensorData_from_MemRef;
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
switch (rank) {
case 0: {
auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args);

View File

@@ -24,64 +24,6 @@ using concretelang::clientlib::CircuitGateShape;
using concretelang::clientlib::PublicArguments;
using concretelang::error::StringError;
// Step a row-major multi-dimensional index forward by one element.
// Works like an odometer: the last dimension is incremented first, and any
// dimension that reaches its size wraps to zero and carries outwards.
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
  int dim = static_cast<int>(rank) - 1;
  while (dim >= 0) {
    if (index[dim] < sizes[dim] - 1) {
      index[dim]++;
      return;
    }
    index[dim] = 0;
    --dim;
  }
}
// Turn a multi-dimensional coordinate into a flat offset. When strides[r] is
// 0 the default (dense, row-major) stride derived from the sizes is used.
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
                    size_t rank) {
  size_t flat = 0;
  size_t denseStride = 1;
  int r = static_cast<int>(rank);
  while (r-- > 0) {
    const size_t stride = (strides[r] == 0) ? denseStride : strides[r];
    flat += index[r] * stride;
    denseStride *= sizes[r];
  }
  return flat;
}
/** Helper function to convert from a MemRefDescriptor to TensorData,
 * assuming the MemRefDescriptor is bufferized. Copies the (possibly strided)
 * memref content into an owned TensorData, then frees the allocation. */
TensorData TensorData_from_MemRef(size_t memref_rank,
                                  encrypted_scalars_t allocated,
                                  encrypted_scalars_t aligned, size_t offset,
                                  size_t *sizes, size_t *strides) {
  TensorData result;
  assert(aligned != nullptr);
  // Copy the memref shape into the result.
  result.sizes.resize(memref_rank);
  for (size_t r = 0; r < memref_rank; r++) {
    result.sizes[r] = sizes[r];
  }
  // Ephemeral multi-dim index used to walk the memref; std::vector instead
  // of raw new[]/delete[] avoids leaking it on an early exit.
  std::vector<size_t> index(memref_rank, 0);
  auto len = result.length();
  result.values.resize(len);
  // TODO: add a fast path for dense result (no real strides)
  for (size_t i = 0; i < len; i++) {
    // A memref element lives at aligned + offset + sum(index[r] * stride[r]).
    // The previous code folded `offset` into g_index and then added it again
    // when indexing `aligned`, mis-reading memrefs with non-zero offset.
    size_t g_index = global_index(index.data(), sizes, strides, memref_rank);
    result.values[i] = aligned[offset + g_index];
    next_coord_index(index.data(), sizes, memref_rank);
  }
  // TEMPORARY: Quick and dirty, but as this function is only used to convert
  // a result of the MLIR program, and the data are copied here, we release
  // the allocated pointer if it is set.
  if (allocated != nullptr) {
    free(allocated);
  }
  return result;
}
outcome::checked<ServerLambda, StringError>
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
std::string funcName) {

View File

@@ -69,22 +69,88 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
<< pos << " is null or missing";
}
/// Checks that `args` matches the arity of the JIT-compiled function's
/// signature, then forwards the raw argument vector to invokeRaw.
/// Returns an error (without invoking) when the argument count mismatches.
llvm::Error JITLambda::invoke(Argument &args) {
  size_t expectedInputs = this->type.getNumParams();
  size_t actualInputs = args.inputs.size();
  if (expectedInputs != actualInputs) {
    // Fixed message: it is emitted by invoke (not invokeRaw) and previously
    // lacked the space before "arguments" ("received 3arguments").
    return StreamStringError("invoke: received ")
           << actualInputs << " arguments instead of " << expectedInputs;
  }
  return invokeRaw(args.rawArg);
}
// A ranked memref is passed flattened to the compiled function: an allocated
// pointer, an aligned pointer, an offset, then `rank` sizes followed by
// `rank` strides. This returns the number of raw slots that flattening uses.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
  const uint64_t pointersAndOffset = 3;
  return pointersAndOffset + 2 * rank;
}
/// Invokes the JIT-compiled lambda on already-prepared public arguments and
/// packages the raw outputs into a PublicResult.
///
/// The raw argument vector handed to invokeRaw is laid out as: one pointer
/// per prepared input, one pointer on the runtime context, then one slot per
/// flattened output (a clear scalar takes 1 slot; anything lowered as a
/// memref takes 3 + 2*rank slots — see
/// numArgOfRankedMemrefCallingConvention).
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
JITLambda::call(clientlib::PublicArguments &args) {
  // invokeRaw needs to have pointers on arguments and a pointers on the result
  // as last argument.
  // Prepare the outputs vector to store the output value of the lambda.
  // First pass: count how many raw slots the outputs occupy.
  auto numOutputs = 0;
  for (auto &output : args.clientParameters.outputs) {
    if (output.shape.dimensions.empty()) {
      // scalar gate
      if (output.encryption.hasValue()) {
        // encrypted scalar : memref<lweSizexi64>, i.e. a rank-1 memref
        numOutputs += numArgOfRankedMemrefCallingConvention(1);
      } else {
        // clear scalar: a single raw value slot
        numOutputs += 1;
      }
    } else {
      // memref gate : rank+1 if the output is encrypted for the lwe size
      // dimension
      auto rank = output.shape.dimensions.size() +
                  (output.encryption.hasValue() ? 1 : 0);
      numOutputs += numArgOfRankedMemrefCallingConvention(rank);
    }
  }
  std::vector<void *> outputs(numOutputs);
  // Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
  // inputs and outputs.
  std::vector<void *> rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ +
                              outputs.size());
  size_t i = 0;
  // Pointers on inputs
  for (auto &arg : args.preparedArgs) {
    rawArgs[i++] = &arg;
  }
  // Pointer on runtime context, the rawArgs take pointer on actual value that
  // is passed to the compiled function.
  auto rtCtxPtr = &args.runtimeContext;
  rawArgs[i++] = &rtCtxPtr;
  // Pointers on outputs: the compiled function writes results through these.
  for (auto &out : outputs) {
    rawArgs[i++] = &out;
  }
  // Invoke
  if (auto err = invokeRaw(rawArgs)) {
    return std::move(err);
  }
  // Store the result to the PublicResult.
  // Second pass: decode the raw output slots back into TensorData buffers,
  // walking `outputs` with the same layout used when counting above.
  std::vector<clientlib::TensorData> buffers;
  {
    size_t outputOffset = 0;
    for (auto &output : args.clientParameters.outputs) {
      if (output.shape.dimensions.empty() && !output.encryption.hasValue()) {
        // clear scalar: the slot holds the value itself, not a pointer
        buffers.push_back(
            clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++]));
      } else {
        // encrypted scalar, and tensor gate are memref: decode the flattened
        // descriptor (allocated, aligned, offset, sizes[rank], strides[rank])
        auto rank = output.shape.dimensions.size() +
                    (output.encryption.hasValue() ? 1 : 0);
        auto allocated = (uint64_t *)outputs[outputOffset++];
        auto aligned = (uint64_t *)outputs[outputOffset++];
        auto offset = (size_t)outputs[outputOffset++];
        // sizes/strides point directly into the `outputs` slots.
        size_t *sizes = (size_t *)&outputs[outputOffset];
        outputOffset += rank;
        size_t *strides = (size_t *)&outputs[outputOffset];
        outputOffset += rank;
        // tensorDataFromMemRef copies the data and frees `allocated`.
        buffers.push_back(clientlib::tensorDataFromMemRef(
            rank, allocated, aligned, offset, sizes, strides));
      }
    }
  }
  return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// Setting the inputs
auto numInputs = 0;

View File

@@ -129,7 +129,8 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
auto keySet = std::move(keySetOrErr.value());
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet),
*compRes.clientParameters};
}
} // namespace concretelang

View File

@@ -38,6 +38,10 @@ compile(std::string outputLib, std::string source,
mlir::concretelang::JitCompilerEngine ce{ccx};
ce.setClientParametersFuncName(funcname);
auto result = ce.compile(sources, outputLib);
if (!result) {
llvm::errs() << result.takeError();
assert(false);
}
assert(result);
return result.get();
}
@@ -72,7 +76,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
auto maybeKeySet = lambda.keySet(getTestKeySetCachePtr(), 0, 0);
ASSERT_TRUE(maybeKeySet.has_value());
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
auto maybePublicArguments = lambda.publicArguments(1, keySet);
auto maybePublicArguments = lambda.publicArguments(1, *keySet);
ASSERT_TRUE(maybePublicArguments.has_value());
auto publicArguments = std::move(maybePublicArguments.value());
@@ -80,7 +84,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
ASSERT_TRUE(publicArguments->serialize(osstream).has_value());
EXPECT_TRUE(osstream.good());
// Direct call without intermediate
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
EXPECT_TRUE(lambda.serializeCall(1, *keySet, osstream));
EXPECT_TRUE(osstream.good());
}