mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(jit): Use PublicArguments instead of JitLambda::Argument to call the lambda (uniform calling to ServerLambda and JitLambda)
This commit is contained in:
@@ -82,19 +82,18 @@ public:
|
||||
/// ServerLambda::real_call_write function. ostream must be in binary mode
|
||||
/// std::ios_base::openmode::binary
|
||||
outcome::checked<void, StringError>
|
||||
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
|
||||
std::ostream &ostream) {
|
||||
serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
|
||||
return publicArguments->serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
|
||||
publicArguments(Args... args, KeySet &keySet) {
|
||||
OUTCOME_TRY(auto clientArguments,
|
||||
EncryptedArguments::create(keySet, args...));
|
||||
|
||||
return clientArguments->exportPublicArguments(clientParameters,
|
||||
keySet->runtimeContext());
|
||||
keySet.runtimeContext());
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
|
||||
|
||||
@@ -120,6 +120,8 @@ static inline bool operator==(const CircuitGateShape &lhs,
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
|
||||
bool isEncrypted() { return encryption.hasValue(); }
|
||||
};
|
||||
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
|
||||
@@ -140,7 +142,32 @@ struct ClientParameters {
|
||||
|
||||
static std::string getClientParametersPath(std::string path);
|
||||
|
||||
LweSecretKeyParam lweSecretKeyParam(CircuitGate gate);
|
||||
outcome::checked<CircuitGate, StringError> input(size_t pos) {
|
||||
if (pos >= inputs.size()) {
|
||||
return StringError("input gate ") << pos << " didn't exists";
|
||||
}
|
||||
return inputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<CircuitGate, StringError> ouput(size_t pos) {
|
||||
if (pos >= outputs.size()) {
|
||||
return StringError("output gate ") << pos << " didn't exists";
|
||||
}
|
||||
return outputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKeyParam, StringError>
|
||||
lweSecretKeyParam(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("gate is not encrypted");
|
||||
}
|
||||
auto secretKey = secretKeys.find(gate.encryption->secretKeyID);
|
||||
if (secretKey == secretKeys.end()) {
|
||||
return StringError("cannot find ")
|
||||
<< gate.encryption->secretKeyID << " in client parameters";
|
||||
}
|
||||
return secretKey->second;
|
||||
}
|
||||
};
|
||||
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
|
||||
@@ -35,12 +35,16 @@ public:
|
||||
/// an EncryptedArguments
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(std::shared_ptr<KeySet> keySet, Args... args) {
|
||||
create(KeySet &keySet, Args... args) {
|
||||
auto arguments = std::make_unique<EncryptedArguments>();
|
||||
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
|
||||
return arguments;
|
||||
}
|
||||
|
||||
static std::unique_ptr<EncryptedArguments> empty() {
|
||||
return std::make_unique<EncryptedArguments>();
|
||||
}
|
||||
|
||||
/// Export encrypted arguments as public arguments, reset the encrypted
|
||||
/// arguments, i.e. move all buffers to the PublicArguments and reset the
|
||||
/// positional counter.
|
||||
@@ -48,31 +52,44 @@ public:
|
||||
exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext);
|
||||
|
||||
public:
|
||||
/// Add a uint8_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint8_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
/// Check that all arguments as been pushed.
|
||||
/// TODO: Remove public method here
|
||||
outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
|
||||
|
||||
public:
|
||||
// Add a uint64_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet);
|
||||
|
||||
/// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
KeySet &keySet);
|
||||
|
||||
// Add a 1D tensor argument with data and size of the dimension.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
|
||||
}
|
||||
|
||||
// Add a tensor argument.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument.
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size}, keySet);
|
||||
}
|
||||
|
||||
/// Add a 2D tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
|
||||
}
|
||||
|
||||
@@ -80,7 +97,7 @@ public:
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
|
||||
}
|
||||
|
||||
@@ -88,41 +105,48 @@ public:
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data, size_t dim1,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<size_t>(&dim1, 1), keySet);
|
||||
outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape, keySet);
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
|
||||
keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> pushArg(size_t width, void *data,
|
||||
outcome::checked<void, StringError> pushArg(size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
KeySet &keySet);
|
||||
|
||||
/// Push a variadic list of arguments.
|
||||
// Recursive case for scalars: extract first scalar argument from
|
||||
// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
|
||||
Arg0 arg0, OtherArgs... others) {
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
|
||||
OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
// Recursive case for tensors: extract pointer and size from
|
||||
// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError>
|
||||
pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, size, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
// Terminal case of pushArgs
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
|
||||
return checkAllArgs(keySet);
|
||||
}
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError>
|
||||
checkPushTooManyArgs(std::shared_ptr<KeySet> keySet);
|
||||
outcome::checked<void, StringError>
|
||||
checkAllArgs(std::shared_ptr<KeySet> keySet);
|
||||
outcome::checked<void, StringError> checkPushTooManyArgs(KeySet &keySet);
|
||||
|
||||
private:
|
||||
// Position of the next pushed argument
|
||||
|
||||
@@ -21,6 +21,11 @@ namespace serverlib {
|
||||
class ServerLambda;
|
||||
}
|
||||
} // namespace concretelang
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
class JITLambda;
|
||||
}
|
||||
} // namespace mlir
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
@@ -45,7 +50,8 @@ public:
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
private:
|
||||
friend class ::concretelang::serverlib::ServerLambda; // from ServerLib
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
friend class ::mlir::concretelang::JITLambda;
|
||||
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
@@ -82,16 +88,27 @@ struct PublicResult {
|
||||
/// Serialize into an output stream.
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
/// Decrypt the result at `pos` as a vector.
|
||||
/// Get the result at `pos` as a vector, if the result is a scalar returns a
|
||||
/// vector of size 1. Decryption happens if the result is encrypted.
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptVector(KeySet &keySet, size_t pos);
|
||||
asClearTextVector(KeySet &keySet, size_t pos);
|
||||
|
||||
private:
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
std::vector<TensorData> buffers;
|
||||
};
|
||||
|
||||
/// Helper function to convert from a scalar to TensorData
|
||||
TensorData tensorDataFromScalar(uint64_t value);
|
||||
|
||||
/// Helper function to convert from MemRefDescriptor to
|
||||
/// TensorData
|
||||
TensorData tensorDataFromMemRef(size_t memref_rank,
|
||||
encrypted_scalars_t allocated,
|
||||
encrypted_scalars_t aligned, size_t offset,
|
||||
size_t *sizes, size_t *strides);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext);
|
||||
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
|
||||
|
||||
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
|
||||
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
|
||||
std::ostream &ostream);
|
||||
|
||||
std::ostream &serializeTensorData(TensorData &values_and_sizes,
|
||||
|
||||
@@ -35,7 +35,7 @@ using encrypted_scalars_t = uint64_t *;
|
||||
|
||||
struct TensorData {
|
||||
std::vector<uint64_t> values; // tensor of rank r + 1
|
||||
std::vector<size_t> sizes; // r sizes
|
||||
std::vector<int64_t> sizes; // r sizes
|
||||
|
||||
inline size_t length() {
|
||||
if (sizes.empty()) {
|
||||
|
||||
@@ -23,10 +23,6 @@ using concretelang::clientlib::encrypted_scalar_t;
|
||||
using concretelang::clientlib::encrypted_scalars_t;
|
||||
using concretelang::clientlib::TensorData;
|
||||
|
||||
TensorData TensorData_from_MemRef(size_t rank, encrypted_scalars_t allocated,
|
||||
encrypted_scalars_t aligned, size_t offset,
|
||||
size_t *sizes, size_t *strides);
|
||||
|
||||
/// ServerLambda is a utility class that allows to call a function of a
|
||||
/// compilation result.
|
||||
class ServerLambda {
|
||||
|
||||
@@ -11,12 +11,14 @@
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#include <concretelang/ClientLib/KeySet.h>
|
||||
#include <concretelang/ClientLib/PublicArguments.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::KeySet;
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
@@ -118,6 +120,11 @@ public:
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
/// Call the JIT lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
call(clientlib::PublicArguments &args);
|
||||
|
||||
private:
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
@@ -127,9 +134,6 @@ public:
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
/// invoke the jit lambda with the Argument.
|
||||
llvm::Error invoke(Argument &args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
std::string name;
|
||||
|
||||
@@ -17,6 +17,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::KeySetCache;
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
@@ -26,34 +27,35 @@ namespace {
|
||||
// Helper function for `JitCompilerEngine::Lambda::operator()`
|
||||
// implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
|
||||
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result);
|
||||
|
||||
// Specialization of `typedResult()` for scalar results, forwarding
|
||||
// scalar value to caller
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
|
||||
uint64_t res = 0;
|
||||
|
||||
if (auto err = arguments.getResult(0, res))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return res;
|
||||
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
if (clearResult.value().size() != 1) {
|
||||
return StreamStringError("typedResult expect only one value but got ")
|
||||
<< clearResult.value().size();
|
||||
}
|
||||
return clearResult.value()[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
|
||||
|
||||
if (auto err = n.takeError())
|
||||
return std::move(err);
|
||||
|
||||
std::vector<T> res(*n);
|
||||
|
||||
if (auto err = arguments.getResult(0, res.data(), res.size()))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return std::move(res);
|
||||
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedVectorResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
return std::move(clearResult.value());
|
||||
}
|
||||
|
||||
// Specializations of `typedResult()` for vector results, initializing
|
||||
@@ -62,151 +64,144 @@ typedVectorResult(JITLambda::Argument &arguments) {
|
||||
//
|
||||
// Cannot factor out into a template template <typename T> inline
|
||||
// llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(JITLambda::Argument &arguments); due to ambiguity with
|
||||
// scalar template
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint8_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return typedVectorResult<uint8_t>(arguments);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint16_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return typedVectorResult<uint16_t>(arguments);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint32_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return typedVectorResult<uint32_t>(arguments);
|
||||
}
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
|
||||
// to ambiguity with scalar template
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint8_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint16_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint16_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint32_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint32_t>(keySet, result);
|
||||
// }
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return typedVectorResult<uint64_t>(arguments);
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint64_t>(keySet, result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildTensorLambdaResult(JITLambda::Argument &arguments) {
|
||||
buildTensorLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(arguments);
|
||||
typedResult<std::vector<T>>(keySet, result);
|
||||
|
||||
if (!tensorOrError)
|
||||
return std::move(tensorOrError.takeError());
|
||||
if (auto err = tensorOrError.takeError())
|
||||
return std::move(err);
|
||||
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
|
||||
result.buffers[0].sizes.end() - 1);
|
||||
|
||||
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
|
||||
arguments.getResultDimensions(0);
|
||||
|
||||
if (!tensorDimOrError)
|
||||
return tensorDimOrError.takeError();
|
||||
|
||||
return std::move(std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, *tensorDimOrError));
|
||||
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, tensorDim);
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for a single result wrapped into
|
||||
// a `LambdaArgument`.
|
||||
template <>
|
||||
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<enum JITLambda::Argument::ResultType> resTy =
|
||||
arguments.getResultType(0);
|
||||
|
||||
if (!resTy)
|
||||
return resTy.takeError();
|
||||
|
||||
switch (*resTy) {
|
||||
case JITLambda::Argument::ResultType::SCALAR: {
|
||||
uint64_t res;
|
||||
|
||||
if (llvm::Error err = arguments.getResult(0, res))
|
||||
return std::move(err);
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto gate = keySet.outputGate(0);
|
||||
// scalar case
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (clearResult.has_error()) {
|
||||
return StreamStringError("typedResult: ") << clearResult.error().mesg;
|
||||
}
|
||||
auto res = clearResult.value()[0];
|
||||
|
||||
return std::make_unique<IntLambdaArgument<uint64_t>>(res);
|
||||
}
|
||||
// tensor case
|
||||
// auto width = gate.shape.width;
|
||||
|
||||
case JITLambda::Argument::ResultType::TENSOR: {
|
||||
llvm::Expected<size_t> width = arguments.getResultWidth(0);
|
||||
// if (width > 32)
|
||||
return buildTensorLambdaResult<uint64_t>(keySet, result);
|
||||
// else if (width > 16)
|
||||
// return buildTensorLambdaResult<uint32_t>(keySet, result);
|
||||
// else if (width > 8)
|
||||
// return buildTensorLambdaResult<uint16_t>(keySet, result);
|
||||
// else if (width <= 8)
|
||||
// return buildTensorLambdaResult<uint8_t>(keySet, result);
|
||||
|
||||
if (!width)
|
||||
return width.takeError();
|
||||
// return StreamStringError("Cannot handle scalars with more than 64 bits");
|
||||
}
|
||||
|
||||
if (*width > 64)
|
||||
return StreamStringError("Cannot handle scalars with more than 64 bits");
|
||||
if (*width > 32)
|
||||
return buildTensorLambdaResult<uint64_t>(arguments);
|
||||
else if (*width > 16)
|
||||
return buildTensorLambdaResult<uint32_t>(arguments);
|
||||
else if (*width > 8)
|
||||
return buildTensorLambdaResult<uint16_t>(arguments);
|
||||
else if (*width <= 8)
|
||||
return buildTensorLambdaResult<uint8_t>(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
return StreamStringError("Unknown result type");
|
||||
} // namespace
|
||||
|
||||
// Adaptor class that adds arguments specified as instances of
|
||||
// `LambdaArgument` to `JitLambda::Argument`.
|
||||
// Adaptor class that push arguments specified as instances of
|
||||
// `LambdaArgument` to `clientlib::EncryptedArguments`.
|
||||
class JITLambdaArgumentAdaptor {
|
||||
public:
|
||||
// Checks if the argument `arg` is an plaintext / encrypted integer
|
||||
// argument or a plaintext / encrypted tensor argument with a
|
||||
// backing integer type `IntT` and adds the argument to `jla` at
|
||||
// position `pos`.
|
||||
// backing integer type `IntT` and push the argument to `encryptedArgs`.
|
||||
//
|
||||
// Returns `true` if `arg` has one of the types above and its value
|
||||
// was successfully added to `jla`, `false` if none of the types
|
||||
// was successfully added to `encryptedArgs`, `false` if none of the types
|
||||
// matches or an error if a type matched, but adding the argument to
|
||||
// `jla` failed.
|
||||
// `encryptedArgs` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
if (llvm::Error err = jla.setArg(pos, ila->getValue()))
|
||||
return std::move(err);
|
||||
else
|
||||
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
if (llvm::Error err =
|
||||
jla.setArg(pos, tla->getValue(), tla->getDimensions()))
|
||||
return std::move(err);
|
||||
else
|
||||
auto res =
|
||||
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
tryAddArg<IntT>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
|
||||
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
// Attempts to add a single argument `arg` to `jla` at position
|
||||
// `pos`. Returns an error if either the argument type is
|
||||
// unsupported or if the argument types is supported, but adding it
|
||||
// to `jla` failed.
|
||||
static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
|
||||
const LambdaArgument &arg) {
|
||||
// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
|
||||
// error if either the argument type is unsupported or if the argument types
|
||||
// is supported, but adding it to `encryptedArgs` failed.
|
||||
static inline llvm::Error
|
||||
addArgument(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
// Try the supported integer types; size_t needs explicit
|
||||
// treatment, since it may alias none of the fixed size integer
|
||||
// types
|
||||
llvm::Expected<bool> successOrError =
|
||||
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
|
||||
uint8_t, size_t>(jla, pos, arg);
|
||||
uint8_t, size_t>(encryptedArgs, arg,
|
||||
keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
@@ -217,7 +212,6 @@ public:
|
||||
return llvm::Error::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// A compiler engine that JIT-compiles a source and produces a lambda
|
||||
// object directly invocable through its call operator.
|
||||
@@ -231,12 +225,15 @@ public:
|
||||
Lambda(Lambda &&other)
|
||||
: innerLambda(std::move(other.innerLambda)),
|
||||
keySet(std::move(other.keySet)),
|
||||
compilationContext(other.compilationContext) {}
|
||||
compilationContext(other.compilationContext),
|
||||
clientParameters(other.clientParameters) {}
|
||||
|
||||
Lambda(std::shared_ptr<CompilationContext> compilationContext,
|
||||
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
|
||||
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet,
|
||||
clientlib::ClientParameters clientParameters)
|
||||
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
|
||||
compilationContext(compilationContext) {}
|
||||
compilationContext(compilationContext),
|
||||
clientParameters(clientParameters) {}
|
||||
|
||||
// Returns the number of arguments required for an invocation of
|
||||
// the lambda
|
||||
@@ -251,81 +248,96 @@ public:
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT>
|
||||
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
// Encrypt the arguments
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
|
||||
for (size_t i = 0; i < lambdaArgs.size(); i++) {
|
||||
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
|
||||
*arguments, i, *lambdaArgs[i])) {
|
||||
*encryptedArgs, *lambdaArgs[i], *this->keySet)) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
auto check = encryptedArgs->checkAllArgs(*this->keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
// Export as public arguments
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
|
||||
// Call the lambda
|
||||
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
// Invocation with an array of arguments of the same type
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
// Encrypt the arguments
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push argument " << i << ": " << err;
|
||||
auto res = encryptedArgs->pushArg(args[i], *keySet);
|
||||
if (res.has_error()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
auto check = encryptedArgs->checkAllArgs(*this->keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
// Export as public arguments
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
|
||||
// Call the lambda
|
||||
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
// Invocation with arguments of different types
|
||||
template <typename ResT = uint64_t, typename... Ts>
|
||||
llvm::Expected<ResT> operator()(const Ts... ts) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::concretelang::JITLambda::Argument::create(*this->keySet.get());
|
||||
// Encrypt the arguments
|
||||
auto encryptedArgs =
|
||||
clientlib::EncryptedArguments::create(*keySet, ts...);
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
// Export as public arguments
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
|
||||
if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
|
||||
// Call the lambda
|
||||
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -364,6 +376,7 @@ public:
|
||||
std::unique_ptr<JITLambda> innerLambda;
|
||||
std::unique_ptr<KeySet> keySet;
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
const clientlib::ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
|
||||
|
||||
@@ -63,12 +63,29 @@ public:
|
||||
keySet(keySet) {}
|
||||
|
||||
outcome::checked<Result, StringError> call(Args... args) {
|
||||
// std::string message;
|
||||
|
||||
// client stream
|
||||
// std::ostringstream clientOuput(std::ios::binary);
|
||||
// client argument encryption
|
||||
OUTCOME_TRY(auto encryptedArgs,
|
||||
clientlib::EncryptedArguments::create(keySet, args...));
|
||||
clientlib::EncryptedArguments::create(*keySet, args...));
|
||||
OUTCOME_TRY(auto publicArgument,
|
||||
encryptedArgs->exportPublicArguments(this->clientParameters,
|
||||
keySet->runtimeContext()));
|
||||
// client argument serialization
|
||||
// publicArgument->serialize(clientOuput);
|
||||
// message = clientOuput.str();
|
||||
|
||||
// server stream
|
||||
// std::istringstream serverInput(message, std::ios::binary);
|
||||
// freeStringMemory(message);
|
||||
//
|
||||
// OUTCOME_TRY(auto publicArguments,
|
||||
// clientlib::PublicArguments::unserialize(
|
||||
// this->clientParameters,
|
||||
// serverInput));
|
||||
|
||||
// server function call
|
||||
auto publicResult = serverLambda.call(*publicArgument);
|
||||
|
||||
|
||||
@@ -24,11 +24,9 @@ ClientLambda::load(std::string functionName, std::string jsonPath) {
|
||||
return StringError("ClientLambda: cannot find function ")
|
||||
<< functionName << " in client parameters" << jsonPath;
|
||||
}
|
||||
|
||||
if (param->outputs.size() != 1) {
|
||||
return StringError("ClientLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size())
|
||||
<< ") != 1 is not supported";
|
||||
<< std::to_string(param->outputs.size()) << ") != 1 is not supprted";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
@@ -54,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
|
||||
return result.decryptVector(keySet, 0);
|
||||
return result.asClearTextVector(keySet, 0);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
|
||||
@@ -53,10 +53,6 @@ std::size_t ClientParameters::hash() {
|
||||
return currentHash;
|
||||
}
|
||||
|
||||
LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) {
|
||||
return secretKeys.find(gate.encryption->secretKeyID)->second;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"dimension", v.dimension},
|
||||
|
||||
@@ -22,34 +22,25 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg((uint64_t)arg, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos++;
|
||||
CircuitGate input = keySet->inputGate(pos);
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
if (input.shape.size != 0) {
|
||||
return StringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
if (!input.encryption.hasValue()) {
|
||||
// clear scalar: just push the argument
|
||||
if (input.shape.width != 64) {
|
||||
return StringError(
|
||||
"scalar argument of with != 64 is not supported for DynamicLambda");
|
||||
}
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return outcome::success();
|
||||
}
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
|
||||
TensorData &values_and_sizes = ciphertextBuffers.back();
|
||||
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
values_and_sizes.values.resize(lweSize);
|
||||
|
||||
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values_and_sizes.values.data(), arg));
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
@@ -66,18 +57,16 @@ EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(std::vector<uint8_t> arg, KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet->inputGate(pos);
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StringError("argument #")
|
||||
@@ -108,7 +97,7 @@ EncryptedArguments::pushArg(size_t width, void *data,
|
||||
}
|
||||
}
|
||||
if (input.encryption.hasValue()) {
|
||||
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
|
||||
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
|
||||
@@ -124,9 +113,14 @@ EncryptedArguments::pushArg(size_t width, void *data,
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values.data() + offset, data8[i]));
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i]));
|
||||
}
|
||||
} // TODO: NON ENCRYPTED, COPY CONTENT TO values_and_sizes
|
||||
} else {
|
||||
values_and_sizes.values.resize(input.shape.size);
|
||||
for (size_t i = 0; i < input.shape.size; i++) {
|
||||
values_and_sizes.values[i] = ((const uint64_t *)data)[i];
|
||||
}
|
||||
}
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
@@ -150,8 +144,8 @@ EncryptedArguments::pushArg(size_t width, void *data,
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
EncryptedArguments::checkPushTooManyArgs(KeySet &keySet) {
|
||||
size_t arity = keySet.numInputs();
|
||||
if (currentPos < arity) {
|
||||
return outcome::success();
|
||||
}
|
||||
@@ -160,8 +154,8 @@ EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::checkAllArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
EncryptedArguments::checkAllArgs(KeySet &keySet) {
|
||||
size_t arity = keySet.numInputs();
|
||||
if (currentPos == arity) {
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -19,10 +19,11 @@ namespace clientlib {
|
||||
using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(
|
||||
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext, std::vector<void *> &&preparedArgs_,
|
||||
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers_)
|
||||
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
|
||||
RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext,
|
||||
std::vector<void *> &&preparedArgs_,
|
||||
std::vector<TensorData> &&ciphertextBuffers_)
|
||||
: clientParameters(clientParameters), runtimeContext(runtimeContext),
|
||||
clearRuntimeContext(clearRuntimeContext) {
|
||||
preparedArgs = std::move(preparedArgs_);
|
||||
@@ -63,7 +64,7 @@ PublicArguments::serialize(std::ostream &ostream) {
|
||||
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
|
||||
assert(aligned != nullptr);
|
||||
auto offset = (size_t)preparedArgs[iPreparedArgs++];
|
||||
std::vector<size_t> sizes; // includes lweSize as last dim
|
||||
std::vector<int64_t> sizes; // includes lweSize as last dim
|
||||
sizes.resize(rank + 1);
|
||||
for (auto dim = 0u; dim < sizes.size(); dim++) {
|
||||
// sizes are part of the client parameters signature
|
||||
@@ -91,7 +92,7 @@ PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize();
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
std::vector<int64_t> sizes = gate.shape.dimensions;
|
||||
sizes.push_back(lweSize);
|
||||
ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
|
||||
@@ -135,14 +136,17 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
return sArguments;
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
PublicResult::decryptVector(KeySet &keySet, size_t pos) {
|
||||
auto lweSize =
|
||||
clientParameters.lweSecretKeyParam(clientParameters.outputs[pos])
|
||||
.lweSize();
|
||||
outcome::checked<std::vector<uint64_t>, StringError>
|
||||
PublicResult::asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
if (!gate.isEncrypted()) {
|
||||
return buffers[pos].values;
|
||||
}
|
||||
|
||||
auto buffer = buffers[pos];
|
||||
decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize);
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
|
||||
std::vector<uint64_t> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = &buffer.values[i * lweSize];
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
|
||||
@@ -150,5 +154,63 @@ PublicResult::decryptVector(KeySet &keySet, size_t pos) {
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
|
||||
// increase multi dim index
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
if (index[r] < sizes[r] - 1) {
|
||||
index[r]++;
|
||||
return;
|
||||
}
|
||||
index[r] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
|
||||
size_t rank) {
|
||||
// compute global index from multi dim index
|
||||
size_t g_index = 0;
|
||||
size_t default_stride = 1;
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
|
||||
default_stride *= sizes[r];
|
||||
}
|
||||
return g_index;
|
||||
}
|
||||
|
||||
TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; }
|
||||
|
||||
TensorData tensorDataFromMemRef(size_t memref_rank,
|
||||
encrypted_scalars_t allocated,
|
||||
encrypted_scalars_t aligned, size_t offset,
|
||||
size_t *sizes, size_t *strides) {
|
||||
TensorData result;
|
||||
assert(aligned != nullptr);
|
||||
result.sizes.resize(memref_rank);
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
result.sizes[r] = sizes[r];
|
||||
}
|
||||
// ephemeral multi dim index to compute global strides
|
||||
size_t *index = new size_t[memref_rank];
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
index[r] = 0;
|
||||
}
|
||||
auto len = result.length();
|
||||
result.values.resize(len);
|
||||
// TODO: add a fast path for dense result (no real strides)
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
int g_index = offset + global_index(index, sizes, strides, memref_rank);
|
||||
result.values[i] = aligned[offset + g_index];
|
||||
next_coord_index(index, sizes, memref_rank);
|
||||
}
|
||||
delete[] index;
|
||||
// TEMPORARY: That quick and dirty but as this function is used only to
|
||||
// convert a result of the mlir program and as data are copied here, we
|
||||
// release the alocated pointer if it set.
|
||||
if (allocated != nullptr) {
|
||||
free(allocated);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -93,7 +93,7 @@ std::ostream &serializeTensorData(uint64_t *values, size_t length,
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
|
||||
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
|
||||
std::ostream &ostream) {
|
||||
size_t length = 1;
|
||||
for (auto size : sizes) {
|
||||
@@ -107,7 +107,7 @@ std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
|
||||
|
||||
std::ostream &serializeTensorData(TensorData &values_and_sizes,
|
||||
std::ostream &ostream) {
|
||||
std::vector<size_t> &sizes = values_and_sizes.sizes;
|
||||
std::vector<int64_t> &sizes = values_and_sizes.sizes;
|
||||
encrypted_scalars_t values = values_and_sizes.values.data();
|
||||
return serializeTensorData(sizes, values, ostream);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args,
|
||||
size_t rank) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = &TensorData_from_MemRef;
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args);
|
||||
|
||||
@@ -24,64 +24,6 @@ using concretelang::clientlib::CircuitGateShape;
|
||||
using concretelang::clientlib::PublicArguments;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
|
||||
// increase multi dim index
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
if (index[r] < sizes[r] - 1) {
|
||||
index[r]++;
|
||||
return;
|
||||
}
|
||||
index[r] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
|
||||
size_t rank) {
|
||||
// compute global index from multi dim index
|
||||
size_t g_index = 0;
|
||||
size_t default_stride = 1;
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
|
||||
default_stride *= sizes[r];
|
||||
}
|
||||
return g_index;
|
||||
}
|
||||
|
||||
/** Helper function to convert from MemRefDescriptor to
|
||||
* TensorData assuming MemRefDescriptor are bufferized */
|
||||
TensorData TensorData_from_MemRef(size_t memref_rank,
|
||||
encrypted_scalars_t allocated,
|
||||
encrypted_scalars_t aligned, size_t offset,
|
||||
size_t *sizes, size_t *strides) {
|
||||
TensorData result;
|
||||
assert(aligned != nullptr);
|
||||
result.sizes.resize(memref_rank);
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
result.sizes[r] = sizes[r];
|
||||
}
|
||||
size_t *index = new size_t[memref_rank]; // ephemeral multi dim index to
|
||||
// compute global strides
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
index[r] = 0;
|
||||
}
|
||||
auto len = result.length();
|
||||
result.values.resize(len);
|
||||
// TODO: add a fast path for dense result (no real strides)
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
int g_index = offset + global_index(index, sizes, strides, memref_rank);
|
||||
result.values[i] = aligned[offset + g_index];
|
||||
next_coord_index(index, sizes, memref_rank);
|
||||
}
|
||||
delete[] index;
|
||||
// TEMPORARY: That quick and dirty but as this function is used only to
|
||||
// convert a result of the mlir program and as data are copied here, we
|
||||
// release the alocated pointer if it set.
|
||||
if (allocated != nullptr) {
|
||||
free(allocated);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
|
||||
std::string funcName) {
|
||||
|
||||
@@ -69,22 +69,88 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
<< pos << " is null or missing";
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invoke(Argument &args) {
|
||||
size_t expectedInputs = this->type.getNumParams();
|
||||
size_t actualInputs = args.inputs.size();
|
||||
if (expectedInputs == actualInputs) {
|
||||
return invokeRaw(args.rawArg);
|
||||
}
|
||||
return StreamStringError("invokeRaw: received ")
|
||||
<< actualInputs << "arguments instead of " << expectedInputs;
|
||||
}
|
||||
|
||||
// memref is a struct which is flattened aligned, allocated pointers, offset,
|
||||
// and two array of rank size for sizes and strides.
|
||||
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
|
||||
return 3 + 2 * rank;
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
JITLambda::call(clientlib::PublicArguments &args) {
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
auto numOutputs = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
if (output.shape.dimensions.empty()) {
|
||||
// scalar gate
|
||||
if (output.encryption.hasValue()) {
|
||||
// encrypted scalar : memref<lweSizexi64>
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(1);
|
||||
} else {
|
||||
// clear scalar
|
||||
numOutputs += 1;
|
||||
}
|
||||
} else {
|
||||
// memref gate : rank+1 if the output is encrypted for the lwe size
|
||||
// dimension
|
||||
auto rank = output.shape.dimensions.size() +
|
||||
(output.encryption.hasValue() ? 1 : 0);
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(rank);
|
||||
}
|
||||
}
|
||||
std::vector<void *> outputs(numOutputs);
|
||||
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
|
||||
// inputs and outputs.
|
||||
std::vector<void *> rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ +
|
||||
outputs.size());
|
||||
size_t i = 0;
|
||||
// Pointers on inputs
|
||||
for (auto &arg : args.preparedArgs) {
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
auto rtCtxPtr = &args.runtimeContext;
|
||||
rawArgs[i++] = &rtCtxPtr;
|
||||
// Pointers on outputs
|
||||
for (auto &out : outputs) {
|
||||
rawArgs[i++] = &out;
|
||||
}
|
||||
|
||||
// Invoke
|
||||
if (auto err = invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Store the result to the PublicResult
|
||||
std::vector<clientlib::TensorData> buffers;
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
if (output.shape.dimensions.empty() && !output.encryption.hasValue()) {
|
||||
// clear scalar
|
||||
buffers.push_back(
|
||||
clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++]));
|
||||
} else {
|
||||
// encrypted scalar, and tensor gate are memref
|
||||
auto rank = output.shape.dimensions.size() +
|
||||
(output.encryption.hasValue() ? 1 : 0);
|
||||
auto allocated = (uint64_t *)outputs[outputOffset++];
|
||||
auto aligned = (uint64_t *)outputs[outputOffset++];
|
||||
auto offset = (size_t)outputs[outputOffset++];
|
||||
size_t *sizes = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
size_t *strides = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
buffers.push_back(clientlib::tensorDataFromMemRef(
|
||||
rank, allocated, aligned, offset, sizes, strides));
|
||||
}
|
||||
}
|
||||
}
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
|
||||
}
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
// Setting the inputs
|
||||
auto numInputs = 0;
|
||||
|
||||
@@ -129,7 +129,8 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
|
||||
auto keySet = std::move(keySetOrErr.value());
|
||||
|
||||
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
|
||||
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet),
|
||||
*compRes.clientParameters};
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -38,6 +38,10 @@ compile(std::string outputLib, std::string source,
|
||||
mlir::concretelang::JitCompilerEngine ce{ccx};
|
||||
ce.setClientParametersFuncName(funcname);
|
||||
auto result = ce.compile(sources, outputLib);
|
||||
if (!result) {
|
||||
llvm::errs() << result.takeError();
|
||||
assert(false);
|
||||
}
|
||||
assert(result);
|
||||
return result.get();
|
||||
}
|
||||
@@ -72,7 +76,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
auto maybeKeySet = lambda.keySet(getTestKeySetCachePtr(), 0, 0);
|
||||
ASSERT_TRUE(maybeKeySet.has_value());
|
||||
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
|
||||
auto maybePublicArguments = lambda.publicArguments(1, keySet);
|
||||
auto maybePublicArguments = lambda.publicArguments(1, *keySet);
|
||||
|
||||
ASSERT_TRUE(maybePublicArguments.has_value());
|
||||
auto publicArguments = std::move(maybePublicArguments.value());
|
||||
@@ -80,7 +84,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
ASSERT_TRUE(publicArguments->serialize(osstream).has_value());
|
||||
EXPECT_TRUE(osstream.good());
|
||||
// Direct call without intermediate
|
||||
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
|
||||
EXPECT_TRUE(lambda.serializeCall(1, *keySet, osstream));
|
||||
EXPECT_TRUE(osstream.good());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user