diff --git a/compiler/include/concretelang/ClientLib/ClientLambda.h b/compiler/include/concretelang/ClientLib/ClientLambda.h index e83a4acc0..e43dabc43 100644 --- a/compiler/include/concretelang/ClientLib/ClientLambda.h +++ b/compiler/include/concretelang/ClientLib/ClientLambda.h @@ -82,19 +82,18 @@ public: /// ServerLambda::real_call_write function. ostream must be in binary mode /// std::ios_base::openmode::binary outcome::checked - serializeCall(Args... args, std::shared_ptr keySet, - std::ostream &ostream) { + serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) { OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet)); return publicArguments->serialize(ostream); } outcome::checked, StringError> - publicArguments(Args... args, std::shared_ptr keySet) { + publicArguments(Args... args, KeySet &keySet) { OUTCOME_TRY(auto clientArguments, EncryptedArguments::create(keySet, args...)); return clientArguments->exportPublicArguments(clientParameters, - keySet->runtimeContext()); + keySet.runtimeContext()); } outcome::checked decryptResult(KeySet &keySet, diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index e2f293794..f76195a02 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -120,6 +120,8 @@ static inline bool operator==(const CircuitGateShape &lhs, struct CircuitGate { llvm::Optional encryption; CircuitGateShape shape; + + bool isEncrypted() { return encryption.hasValue(); } }; static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) { return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape; @@ -140,7 +142,32 @@ struct ClientParameters { static std::string getClientParametersPath(std::string path); - LweSecretKeyParam lweSecretKeyParam(CircuitGate gate); + outcome::checked input(size_t pos) { + if (pos >= inputs.size()) { + return StringError("input gate ") << pos << " didn't exists"; + } + return inputs[pos]; + } + + outcome::checked ouput(size_t pos) { + if (pos >= outputs.size()) { + return StringError("output gate ") << pos << " didn't exists"; + } + return outputs[pos]; + } + + outcome::checked + lweSecretKeyParam(CircuitGate gate) { + if (!gate.encryption.hasValue()) { + return StringError("gate is not encrypted"); + } + auto secretKey = secretKeys.find(gate.encryption->secretKeyID); + if (secretKey == secretKeys.end()) { + return StringError("cannot find ") + << gate.encryption->secretKeyID << " in client parameters"; + } + return secretKey->second; + } }; static inline bool operator==(const ClientParameters &lhs, diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h index 70e62b5bb..263f665ee 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -35,12 +35,16 @@ public: /// an EncryptedArguments template static outcome::checked, StringError> - create(std::shared_ptr keySet, Args... args) { + create(KeySet &keySet, Args... args) { auto arguments = std::make_unique(); OUTCOME_TRYV(arguments->pushArgs(keySet, args...)); return arguments; } + static std::unique_ptr empty() { + return std::make_unique(); + } + /// Export encrypted arguments as public arguments, reset the encrypted /// arguments, i.e. move all buffers to the PublicArguments and reset the /// positional counter. 
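// --- Reviewer note (not part of the patch) ---------------------------------
// Minimal usage sketch of the API reshaped by the hunks above: KeySet is now
// passed by reference, and the ClientParameters gate accessors return checked
// results instead of asserting.  `prepareCall` is a hypothetical helper; the
// outcome::checked<..., StringError> spellings are assumed (template arguments
// were elided by this rendering of the diff), and KeySet is assumed to be
// visible through the ClientLib headers below.
#include <cstdint>
#include <memory>

#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/PublicArguments.h"

namespace clientlib = concretelang::clientlib;
using concretelang::error::StringError;

outcome::checked<std::unique_ptr<clientlib::PublicArguments>, StringError>
prepareCall(clientlib::ClientParameters &params, clientlib::KeySet &keySet,
            uint64_t scalar, const uint8_t *tensor, int64_t tensorLen) {
  // Gate accessors now fail cleanly on a bad position or a missing secret key.
  OUTCOME_TRY(auto gate0, params.input(0));
  if (gate0.isEncrypted()) {
    OUTCOME_TRY(auto skParam, params.lweSecretKeyParam(gate0));
    (void)skParam; // e.g. skParam.lweSize() to size ciphertext buffers
  }
  // Push arguments against the KeySet reference, then export.
  auto args = clientlib::EncryptedArguments::empty();
  OUTCOME_TRYV(args->pushArg(scalar, keySet));
  OUTCOME_TRYV(args->pushArg(tensor, tensorLen, keySet));
  OUTCOME_TRYV(args->checkAllArgs(keySet));
  return args->exportPublicArguments(params, keySet.runtimeContext());
}
// ---------------------------------------------------------------------------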
@@ -48,31 +52,44 @@ public: exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext); -public: - /// Add a uint8_t scalar argument. - outcome::checked pushArg(uint8_t arg, - std::shared_ptr keySet); + /// Check that all arguments as been pushed. + /// TODO: Remove public method here + outcome::checked checkAllArgs(KeySet &keySet); +public: // Add a uint64_t scalar argument. - outcome::checked pushArg(uint64_t arg, - std::shared_ptr keySet); + outcome::checked pushArg(uint64_t arg, KeySet &keySet); /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, - std::shared_ptr keySet); + KeySet &keySet); + + // Add a 1D tensor argument with data and size of the dimension. + template + outcome::checked pushArg(const T *data, int64_t dim1, + KeySet &keySet) { + return pushArg(std::vector(data, data + dim1), keySet); + } + + // Add a tensor argument. + template + outcome::checked + pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { + return pushArg(8 * sizeof(T), static_cast(data), shape, + keySet); + } /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, - std::shared_ptr keySet) { + KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size}, keySet); } /// Add a 2D tensor argument. template outcome::checked - pushArg(std::array, size0> arg, - std::shared_ptr keySet) { + pushArg(std::array, size0> arg, KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size0, size1}, keySet); } @@ -80,7 +97,7 @@ public: template outcome::checked pushArg(std::array, size1>, size0> arg, - std::shared_ptr keySet) { + KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet); } @@ -88,41 +105,48 @@ public: // Set a argument at the given pos as a 1D tensor of T. template - outcome::checked pushArg(T *data, size_t dim1, - std::shared_ptr keySet) { - return pushArg(data, llvm::ArrayRef(&dim1, 1), keySet); + outcome::checked pushArg(T *data, int64_t dim1, + KeySet &keySet) { + return pushArg(data, llvm::ArrayRef(&dim1, 1), keySet); } // Set a argument at the given pos as a tensor of T. template - outcome::checked pushArg(T *data, - llvm::ArrayRef shape, - std::shared_ptr keySet) { - return pushArg(8 * sizeof(T), static_cast(data), shape, keySet); + outcome::checked + pushArg(T *data, llvm::ArrayRef shape, KeySet &keySet) { + return pushArg(8 * sizeof(T), static_cast(data), shape, + keySet); } - outcome::checked pushArg(size_t width, void *data, + outcome::checked pushArg(size_t width, const void *data, llvm::ArrayRef shape, - std::shared_ptr keySet); + KeySet &keySet); - /// Push a variadic list of arguments. + // Recursive case for scalars: extract first scalar argument from + // parameter pack and forward rest template - outcome::checked pushArgs(std::shared_ptr keySet, - Arg0 arg0, OtherArgs... others) { + outcome::checked pushArgs(KeySet &keySet, Arg0 arg0, + OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, keySet)); return pushArgs(keySet, others...); } + // Recursive case for tensors: extract pointer and size from + // parameter pack and forward rest + template + outcome::checked + pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... 
others) { + OUTCOME_TRYV(pushArg(arg0, size, keySet)); + return pushArgs(keySet, others...); + } + // Terminal case of pushArgs - outcome::checked pushArgs(std::shared_ptr keySet) { + outcome::checked pushArgs(KeySet &keySet) { return checkAllArgs(keySet); } private: - outcome::checked - checkPushTooManyArgs(std::shared_ptr keySet); - outcome::checked - checkAllArgs(std::shared_ptr keySet); + outcome::checked checkPushTooManyArgs(KeySet &keySet); private: // Position of the next pushed argument diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 95d4d644d..8485ccbb3 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -21,6 +21,11 @@ namespace serverlib { class ServerLambda; } } // namespace concretelang +namespace mlir { +namespace concretelang { +class JITLambda; +} +} // namespace mlir namespace concretelang { namespace clientlib { @@ -45,7 +50,8 @@ public: outcome::checked serialize(std::ostream &ostream); private: - friend class ::concretelang::serverlib::ServerLambda; // from ServerLib + friend class ::concretelang::serverlib::ServerLambda; + friend class ::mlir::concretelang::JITLambda; outcome::checked unserializeArgs(std::istream &istream); @@ -82,16 +88,27 @@ struct PublicResult { /// Serialize into an output stream. outcome::checked serialize(std::ostream &ostream); - /// Decrypt the result at `pos` as a vector. + /// Get the result at `pos` as a vector, if the result is a scalar returns a + /// vector of size 1. Decryption happens if the result is encrypted. outcome::checked, StringError> - decryptVector(KeySet &keySet, size_t pos); + asClearTextVector(KeySet &keySet, size_t pos); -private: + // private: TODO tmp friend class ::concretelang::serverlib::ServerLambda; ClientParameters clientParameters; std::vector buffers; }; +/// Helper function to convert from a scalar to TensorData +TensorData tensorDataFromScalar(uint64_t value); + +/// Helper function to convert from MemRefDescriptor to +/// TensorData +TensorData tensorDataFromMemRef(size_t memref_rank, + encrypted_scalars_t allocated, + encrypted_scalars_t aligned, size_t offset, + size_t *sizes, size_t *strides); + } // namespace clientlib } // namespace concretelang diff --git a/compiler/include/concretelang/ClientLib/Serializers.h b/compiler/include/concretelang/ClientLib/Serializers.h index 73c6cfe1f..d850949d8 100644 --- a/compiler/include/concretelang/ClientLib/Serializers.h +++ b/compiler/include/concretelang/ClientLib/Serializers.h @@ -56,7 +56,7 @@ std::ostream &operator<<(std::ostream &ostream, const RuntimeContext &runtimeContext); std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext); -std::ostream &serializeTensorData(std::vector &sizes, uint64_t *values, +std::ostream &serializeTensorData(std::vector &sizes, uint64_t *values, std::ostream &ostream); std::ostream &serializeTensorData(TensorData &values_and_sizes, diff --git a/compiler/include/concretelang/ClientLib/Types.h b/compiler/include/concretelang/ClientLib/Types.h index 53367d38e..70066eb68 100644 --- a/compiler/include/concretelang/ClientLib/Types.h +++ b/compiler/include/concretelang/ClientLib/Types.h @@ -35,7 +35,7 @@ using encrypted_scalars_t = uint64_t *; struct TensorData { std::vector values; // tensor of rank r + 1 - std::vector sizes; // r sizes + std::vector sizes; // r sizes inline size_t length() { if (sizes.empty()) { diff --git 
a/compiler/include/concretelang/ServerLib/ServerLambda.h b/compiler/include/concretelang/ServerLib/ServerLambda.h index 14a444a94..77374405d 100644 --- a/compiler/include/concretelang/ServerLib/ServerLambda.h +++ b/compiler/include/concretelang/ServerLib/ServerLambda.h @@ -23,10 +23,6 @@ using concretelang::clientlib::encrypted_scalar_t; using concretelang::clientlib::encrypted_scalars_t; using concretelang::clientlib::TensorData; -TensorData TensorData_from_MemRef(size_t rank, encrypted_scalars_t allocated, - encrypted_scalars_t aligned, size_t offset, - size_t *sizes, size_t *strides); - /// ServerLambda is a utility class that allows to call a function of a /// compilation result. class ServerLambda { diff --git a/compiler/include/concretelang/Support/Jit.h b/compiler/include/concretelang/Support/Jit.h index adf543e9d..2dcfcc44a 100644 --- a/compiler/include/concretelang/Support/Jit.h +++ b/compiler/include/concretelang/Support/Jit.h @@ -11,12 +11,14 @@ #include #include +#include namespace mlir { namespace concretelang { using ::concretelang::clientlib::CircuitGate; using ::concretelang::clientlib::KeySet; +namespace clientlib = ::concretelang::clientlib; /// JITLambda is a tool to JIT compile an mlir module and to invoke a function /// of the module. @@ -118,6 +120,11 @@ public: llvm::function_ref optPipeline, llvm::Optional runtimeLibPath = {}); + /// Call the JIT lambda with the public arguments. + llvm::Expected> + call(clientlib::PublicArguments &args); + +private: /// invokeRaw execute the jit lambda with a list of Argument, the last one is /// used to store the result of the computation. /// Example: @@ -127,9 +134,6 @@ public: /// lambda.invokeRaw(args); llvm::Error invokeRaw(llvm::MutableArrayRef args); - /// invoke the jit lambda with the Argument. - llvm::Error invoke(Argument &args); - private: mlir::LLVM::LLVMFunctionType type; std::string name; diff --git a/compiler/include/concretelang/Support/JitCompilerEngine.h b/compiler/include/concretelang/Support/JitCompilerEngine.h index 88e5ad477..f8cbfe140 100644 --- a/compiler/include/concretelang/Support/JitCompilerEngine.h +++ b/compiler/include/concretelang/Support/JitCompilerEngine.h @@ -17,6 +17,7 @@ namespace mlir { namespace concretelang { using ::concretelang::clientlib::KeySetCache; +namespace clientlib = ::concretelang::clientlib; namespace { // Generic function template as well as specializations of @@ -26,34 +27,35 @@ namespace { // Helper function for `JitCompilerEngine::Lambda::operator()` // implementing type-dependent preparation of the result. 
template -llvm::Expected typedResult(JITLambda::Argument &arguments); +llvm::Expected typedResult(clientlib::KeySet &keySet, + clientlib::PublicResult &result); // Specialization of `typedResult()` for scalar results, forwarding // scalar value to caller template <> -inline llvm::Expected typedResult(JITLambda::Argument &arguments) { - uint64_t res = 0; - - if (auto err = arguments.getResult(0, res)) - return StreamStringError() << "Cannot retrieve result:" << err; - - return res; +inline llvm::Expected typedResult(clientlib::KeySet &keySet, + clientlib::PublicResult &result) { + auto clearResult = result.asClearTextVector(keySet, 0); + if (!clearResult.has_value()) { + return StreamStringError("typedResult cannot get clear text vector") + << clearResult.error().mesg; + } + if (clearResult.value().size() != 1) { + return StreamStringError("typedResult expect only one value but got ") + << clearResult.value().size(); + } + return clearResult.value()[0]; } template inline llvm::Expected> -typedVectorResult(JITLambda::Argument &arguments) { - llvm::Expected n = arguments.getResultVectorSize(0); - - if (auto err = n.takeError()) - return std::move(err); - - std::vector res(*n); - - if (auto err = arguments.getResult(0, res.data(), res.size())) - return StreamStringError() << "Cannot retrieve result:" << err; - - return std::move(res); +typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + auto clearResult = result.asClearTextVector(keySet, 0); + if (!clearResult.has_value()) { + return StreamStringError("typedVectorResult cannot get clear text vector") + << clearResult.error().mesg; + } + return std::move(clearResult.value()); } // Specializations of `typedResult()` for vector results, initializing @@ -62,151 +64,144 @@ typedVectorResult(JITLambda::Argument &arguments) { // // Cannot factor out into a template template inline // llvm::Expected> -// typedResult(JITLambda::Argument &arguments); due to ambiguity with -// scalar template -template <> -inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { - return typedVectorResult(arguments); -} -template <> -inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { - return typedVectorResult(arguments); -} -template <> -inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { - return typedVectorResult(arguments); -} +// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due +// to ambiguity with scalar template +// template <> +// inline llvm::Expected> +// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { +// return typedVectorResult(keySet, result); +// } +// template <> +// inline llvm::Expected> +// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { +// return typedVectorResult(keySet, result); +// } +// template <> +// inline llvm::Expected> +// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { +// return typedVectorResult(keySet, result); +// } template <> inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { - return typedVectorResult(arguments); +typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + return typedVectorResult(keySet, result); } template llvm::Expected> -buildTensorLambdaResult(JITLambda::Argument &arguments) { +buildTensorLambdaResult(clientlib::KeySet &keySet, + clientlib::PublicResult &result) { llvm::Expected> tensorOrError = - typedResult>(arguments); + typedResult>(keySet, result); - if (!tensorOrError) - return 
std::move(tensorOrError.takeError()); + if (auto err = tensorOrError.takeError()) + return std::move(err); + std::vector tensorDim(result.buffers[0].sizes.begin(), + result.buffers[0].sizes.end() - 1); - llvm::Expected> tensorDimOrError = - arguments.getResultDimensions(0); - - if (!tensorDimOrError) - return tensorDimOrError.takeError(); - - return std::move(std::make_unique>>( - *tensorOrError, *tensorDimOrError)); + return std::make_unique>>( + *tensorOrError, tensorDim); } // Specialization of `typedResult()` for a single result wrapped into // a `LambdaArgument`. template <> inline llvm::Expected> -typedResult(JITLambda::Argument &arguments) { - llvm::Expected resTy = - arguments.getResultType(0); - - if (!resTy) - return resTy.takeError(); - - switch (*resTy) { - case JITLambda::Argument::ResultType::SCALAR: { - uint64_t res; - - if (llvm::Error err = arguments.getResult(0, res)) - return std::move(err); +typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + auto gate = keySet.outputGate(0); + // scalar case + if (gate.shape.dimensions.empty()) { + auto clearResult = result.asClearTextVector(keySet, 0); + if (clearResult.has_error()) { + return StreamStringError("typedResult: ") << clearResult.error().mesg; + } + auto res = clearResult.value()[0]; return std::make_unique>(res); } + // tensor case + // auto width = gate.shape.width; - case JITLambda::Argument::ResultType::TENSOR: { - llvm::Expected width = arguments.getResultWidth(0); + // if (width > 32) + return buildTensorLambdaResult(keySet, result); + // else if (width > 16) + // return buildTensorLambdaResult(keySet, result); + // else if (width > 8) + // return buildTensorLambdaResult(keySet, result); + // else if (width <= 8) + // return buildTensorLambdaResult(keySet, result); - if (!width) - return width.takeError(); + // return StreamStringError("Cannot handle scalars with more than 64 bits"); +} - if (*width > 64) - return StreamStringError("Cannot handle scalars with more than 64 bits"); - if (*width > 32) - return buildTensorLambdaResult(arguments); - else if (*width > 16) - return buildTensorLambdaResult(arguments); - else if (*width > 8) - return buildTensorLambdaResult(arguments); - else if (*width <= 8) - return buildTensorLambdaResult(arguments); - } - } - - return StreamStringError("Unknown result type"); } // namespace -// Adaptor class that adds arguments specified as instances of -// `LambdaArgument` to `JitLambda::Argument`. +// Adaptor class that push arguments specified as instances of +// `LambdaArgument` to `clientlib::EncryptedArguments`. class JITLambdaArgumentAdaptor { public: // Checks if the argument `arg` is an plaintext / encrypted integer // argument or a plaintext / encrypted tensor argument with a - // backing integer type `IntT` and adds the argument to `jla` at - // position `pos`. + // backing integer type `IntT` and push the argument to `encryptedArgs`. // // Returns `true` if `arg` has one of the types above and its value - // was successfully added to `jla`, `false` if none of the types + // was successfully added to `encryptedArgs`, `false` if none of the types // matches or an error if a type matched, but adding the argument to - // `jla` failed. + // `encryptedArgs` failed. 
template static inline llvm::Expected - tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { + tryAddArg(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { if (auto ila = arg.dyn_cast>()) { - if (llvm::Error err = jla.setArg(pos, ila->getValue())) - return std::move(err); - else + auto res = encryptedArgs.pushArg(ila->getValue(), keySet); + if (!res.has_value()) { + return StreamStringError(res.error().mesg); + } else { return true; + } } else if (auto tla = arg.dyn_cast< TensorLambdaArgument>>()) { - if (llvm::Error err = - jla.setArg(pos, tla->getValue(), tla->getDimensions())) - return std::move(err); - else + auto res = + encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet); + if (!res.has_value()) { + return StreamStringError(res.error().mesg); + } else { return true; + } } - return false; } // Recursive case for `tryAddArg(...)` template static inline llvm::Expected - tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { - llvm::Expected successOrError = tryAddArg(jla, pos, arg); + tryAddArg(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { + llvm::Expected successOrError = + tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) return successOrError.takeError(); if (successOrError.get() == false) - return tryAddArg(jla, pos, arg); + return tryAddArg(encryptedArgs, arg, keySet); else return true; } - // Attempts to add a single argument `arg` to `jla` at position - // `pos`. Returns an error if either the argument type is - // unsupported or if the argument types is supported, but adding it - // to `jla` failed. - static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos, - const LambdaArgument &arg) { + // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an + // error if either the argument type is unsupported or if the argument types + // is supported, but adding it to `encryptedArgs` failed. + static inline llvm::Error + addArgument(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { // Try the supported integer types; size_t needs explicit // treatment, since it may alias none of the fixed size integer // types llvm::Expected successOrError = JITLambdaArgumentAdaptor::tryAddArg(jla, pos, arg); + uint8_t, size_t>(encryptedArgs, arg, + keySet); if (!successOrError) return successOrError.takeError(); @@ -217,7 +212,6 @@ public: return llvm::Error::success(); } }; -} // namespace // A compiler engine that JIT-compiles a source and produces a lambda // object directly invocable through its call operator. 
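// --- Reviewer note (not part of the patch) ---------------------------------
// The adaptor above replaces the old positional JITLambda::Argument::setArg
// calls.  A minimal sketch of how a caller might now drive it, written as it
// would appear inside namespace mlir::concretelang (same unqualified names as
// this header); `encryptAll` is a hypothetical helper.
static llvm::Expected<std::unique_ptr<clientlib::EncryptedArguments>>
encryptAll(llvm::ArrayRef<LambdaArgument *> lambdaArgs,
           clientlib::KeySet &keySet) {
  auto encryptedArgs = clientlib::EncryptedArguments::empty();
  for (const LambdaArgument *arg : lambdaArgs) {
    // Tries uint64_t, ..., uint8_t and size_t backed scalar/tensor arguments
    // in turn; errors out if no supported type matches or encryption fails.
    if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(*encryptedArgs,
                                                                *arg, keySet))
      return std::move(err);
  }
  // Every input gate of the circuit must have been pushed exactly once.
  auto check = encryptedArgs->checkAllArgs(keySet);
  if (check.has_error())
    return StreamStringError(check.error().mesg);
  return std::move(encryptedArgs);
}
// The result can then go through exportPublicArguments(), JITLambda::call()
// and typedResult<T>() exactly as in the Lambda::operator() overloads below.
// ---------------------------------------------------------------------------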
@@ -231,12 +225,15 @@ public: Lambda(Lambda &&other) : innerLambda(std::move(other.innerLambda)), keySet(std::move(other.keySet)), - compilationContext(other.compilationContext) {} + compilationContext(other.compilationContext), + clientParameters(other.clientParameters) {} Lambda(std::shared_ptr compilationContext, - std::unique_ptr lambda, std::unique_ptr keySet) + std::unique_ptr lambda, std::unique_ptr keySet, + clientlib::ClientParameters clientParameters) : innerLambda(std::move(lambda)), keySet(std::move(keySet)), - compilationContext(compilationContext) {} + compilationContext(compilationContext), + clientParameters(clientParameters) {} // Returns the number of arguments required for an invocation of // the lambda @@ -251,81 +248,96 @@ public: template llvm::Expected operator()(llvm::ArrayRef lambdaArgs) { - // Create the arguments of the JIT lambda - llvm::Expected> argsOrErr = - mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); - - if (llvm::Error err = argsOrErr.takeError()) - return StreamStringError("Could not create lambda arguments"); - - // Set the arguments - std::unique_ptr arguments = - std::move(argsOrErr.get()); + // Encrypt the arguments + auto encryptedArgs = clientlib::EncryptedArguments::empty(); for (size_t i = 0; i < lambdaArgs.size(); i++) { if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( - *arguments, i, *lambdaArgs[i])) { + *encryptedArgs, *lambdaArgs[i], *this->keySet)) { return std::move(err); } } - // Invoke the lambda - if (auto err = this->innerLambda->invoke(*arguments)) - return StreamStringError() << "Cannot invoke lambda:" << err; + auto check = encryptedArgs->checkAllArgs(*this->keySet); + if (check.has_error()) { + return StreamStringError(check.error().mesg); + } - return std::move(typedResult(*arguments)); + // Export as public arguments + auto publicArguments = encryptedArgs->exportPublicArguments( + clientParameters, keySet->runtimeContext()); + if (!publicArguments.has_value()) { + return StreamStringError(publicArguments.error().mesg); + } + + // Call the lambda + auto publicResult = this->innerLambda->call(*publicArguments.value()); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + + return typedResult(*keySet, **publicResult); } // Invocation with an array of arguments of the same type template llvm::Expected operator()(const llvm::ArrayRef args) { - // Create the arguments of the JIT lambda - llvm::Expected> argsOrErr = - mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); - - if (llvm::Error err = argsOrErr.takeError()) - return StreamStringError("Could not create lambda arguments"); - - // Set the arguments - std::unique_ptr arguments = - std::move(argsOrErr.get()); + // Encrypt the arguments + auto encryptedArgs = clientlib::EncryptedArguments::empty(); for (size_t i = 0; i < args.size(); i++) { - if (auto err = arguments->setArg(i, args[i])) { - return StreamStringError() - << "Cannot push argument " << i << ": " << err; + auto res = encryptedArgs->pushArg(args[i], *keySet); + if (res.has_error()) { + return StreamStringError(res.error().mesg); } } - // Invoke the lambda - if (auto err = this->innerLambda->invoke(*arguments)) - return StreamStringError() << "Cannot invoke lambda:" << err; + auto check = encryptedArgs->checkAllArgs(*this->keySet); + if (check.has_error()) { + return StreamStringError(check.error().mesg); + } - return std::move(typedResult(*arguments)); + // Export as public arguments + auto publicArguments = encryptedArgs->exportPublicArguments( 
+ clientParameters, keySet->runtimeContext()); + if (!publicArguments.has_value()) { + return StreamStringError(publicArguments.error().mesg); + } + + // Call the lambda + auto publicResult = this->innerLambda->call(*publicArguments.value()); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + + return typedResult(*keySet, **publicResult); } // Invocation with arguments of different types template llvm::Expected operator()(const Ts... ts) { - // Create the arguments of the JIT lambda - llvm::Expected> argsOrErr = - mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); + // Encrypt the arguments + auto encryptedArgs = + clientlib::EncryptedArguments::create(*keySet, ts...); - if (llvm::Error err = argsOrErr.takeError()) - return StreamStringError("Could not create lambda arguments"); + if (encryptedArgs.has_error()) { + return StreamStringError(encryptedArgs.error().mesg); + } - // Set the arguments - std::unique_ptr arguments = - std::move(argsOrErr.get()); + // Export as public arguments + auto publicArguments = encryptedArgs.value()->exportPublicArguments( + clientParameters, keySet->runtimeContext()); + if (!publicArguments.has_value()) { + return StreamStringError(publicArguments.error().mesg); + } - if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...)) + // Call the lambda + auto publicResult = this->innerLambda->call(*publicArguments.value()); + if (auto err = publicResult.takeError()) { return std::move(err); + } - // Invoke the lambda - if (auto err = this->innerLambda->invoke(*arguments)) - return StreamStringError() << "Cannot invoke lambda:" << err; - - return std::move(typedResult(*arguments)); + return typedResult(*keySet, **publicResult); } protected: @@ -364,6 +376,7 @@ public: std::unique_ptr innerLambda; std::unique_ptr keySet; std::shared_ptr compilationContext; + const clientlib::ClientParameters clientParameters; }; JitCompilerEngine(std::shared_ptr compilationContext = diff --git a/compiler/include/concretelang/TestLib/TestTypedLambda.h b/compiler/include/concretelang/TestLib/TestTypedLambda.h index 2a07b5598..2115f2330 100644 --- a/compiler/include/concretelang/TestLib/TestTypedLambda.h +++ b/compiler/include/concretelang/TestLib/TestTypedLambda.h @@ -63,12 +63,29 @@ public: keySet(keySet) {} outcome::checked call(Args... 
args) { + // std::string message; + + // client stream + // std::ostringstream clientOuput(std::ios::binary); // client argument encryption OUTCOME_TRY(auto encryptedArgs, - clientlib::EncryptedArguments::create(keySet, args...)); + clientlib::EncryptedArguments::create(*keySet, args...)); OUTCOME_TRY(auto publicArgument, encryptedArgs->exportPublicArguments(this->clientParameters, keySet->runtimeContext())); + // client argument serialization + // publicArgument->serialize(clientOuput); + // message = clientOuput.str(); + + // server stream + // std::istringstream serverInput(message, std::ios::binary); + // freeStringMemory(message); + // + // OUTCOME_TRY(auto publicArguments, + // clientlib::PublicArguments::unserialize( + // this->clientParameters, + // serverInput)); + // server function call auto publicResult = serverLambda.call(*publicArgument); diff --git a/compiler/lib/ClientLib/ClientLambda.cpp b/compiler/lib/ClientLib/ClientLambda.cpp index 5d4ace6d9..5907546f5 100644 --- a/compiler/lib/ClientLib/ClientLambda.cpp +++ b/compiler/lib/ClientLib/ClientLambda.cpp @@ -24,11 +24,9 @@ ClientLambda::load(std::string functionName, std::string jsonPath) { return StringError("ClientLambda: cannot find function ") << functionName << " in client parameters" << jsonPath; } - if (param->outputs.size() != 1) { return StringError("ClientLambda: output arity (") - << std::to_string(param->outputs.size()) - << ") != 1 is not supported"; + << std::to_string(param->outputs.size()) << ") != 1 is not supprted"; } if (!param->outputs[0].encryption.hasValue()) { @@ -54,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) { outcome::checked, StringError> ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) { - return result.decryptVector(keySet, 0); + return result.asClearTextVector(keySet, 0); } outcome::checked errorResultRank(size_t expected, diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp index 8caf9c4e6..ae59c4833 100644 --- a/compiler/lib/ClientLib/ClientParameters.cpp +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -53,10 +53,6 @@ std::size_t ClientParameters::hash() { return currentHash; } -LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) { - return secretKeys.find(gate.encryption->secretKeyID)->second; -} - llvm::json::Value toJSON(const LweSecretKeyParam &v) { llvm::json::Object object{ {"dimension", v.dimension}, diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp index 5219ecf1f..db855e6b0 100644 --- a/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -22,34 +22,25 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, } outcome::checked -EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr keySet) { - return pushArg((uint64_t)arg, keySet); -} - -outcome::checked -EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr keySet) { +EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos++; - CircuitGate input = keySet->inputGate(pos); + CircuitGate input = keySet.inputGate(pos); if (input.shape.size != 0) { return StringError("argument #") << pos << " is not a scalar"; } if (!input.encryption.hasValue()) { // clear scalar: just push the argument - if (input.shape.width != 64) { - return StringError( - "scalar argument of with != 64 is not supported for 
DynamicLambda"); - } preparedArgs.push_back((void *)arg); return outcome::success(); } ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty TensorData &values_and_sizes = ciphertextBuffers.back(); - auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize(); + auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize(); values_and_sizes.sizes.push_back(lweSize); values_and_sizes.values.resize(lweSize); - OUTCOME_TRYV(keySet->encrypt_lwe(pos, values_and_sizes.values.data(), arg)); + OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg)); // Note: Since we bufferized lwe ciphertext take care of memref calling // convention // allocated @@ -66,18 +57,16 @@ EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr keySet) { } outcome::checked -EncryptedArguments::pushArg(std::vector arg, - std::shared_ptr keySet) { +EncryptedArguments::pushArg(std::vector arg, KeySet &keySet) { return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet); } outcome::checked -EncryptedArguments::pushArg(size_t width, void *data, - llvm::ArrayRef shape, - std::shared_ptr keySet) { +EncryptedArguments::pushArg(size_t width, const void *data, + llvm::ArrayRef shape, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos; - CircuitGate input = keySet->inputGate(pos); + CircuitGate input = keySet.inputGate(pos); // Check the width of data if (input.shape.width > 64) { return StringError("argument #") @@ -108,7 +97,7 @@ EncryptedArguments::pushArg(size_t width, void *data, } } if (input.encryption.hasValue()) { - auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize(); + auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize(); values_and_sizes.sizes.push_back(lweSize); // Encrypted tensor: for now we support only 8 bits for encrypted tensor @@ -124,9 +113,14 @@ EncryptedArguments::pushArg(size_t width, void *data, // Allocate ciphertexts and encrypt, for every values in tensor for (size_t i = 0, offset = 0; i < input.shape.size; i++, offset += lweSize) { - OUTCOME_TRYV(keySet->encrypt_lwe(pos, values.data() + offset, data8[i])); + OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i])); } - } // TODO: NON ENCRYPTED, COPY CONTENT TO values_and_sizes + } else { + values_and_sizes.values.resize(input.shape.size); + for (size_t i = 0; i < input.shape.size; i++) { + values_and_sizes.values[i] = ((const uint64_t *)data)[i]; + } + } // allocated preparedArgs.push_back(nullptr); // aligned @@ -150,8 +144,8 @@ EncryptedArguments::pushArg(size_t width, void *data, } outcome::checked -EncryptedArguments::checkPushTooManyArgs(std::shared_ptr keySet) { - size_t arity = keySet->numInputs(); +EncryptedArguments::checkPushTooManyArgs(KeySet &keySet) { + size_t arity = keySet.numInputs(); if (currentPos < arity) { return outcome::success(); } @@ -160,8 +154,8 @@ EncryptedArguments::checkPushTooManyArgs(std::shared_ptr keySet) { } outcome::checked -EncryptedArguments::checkAllArgs(std::shared_ptr keySet) { - size_t arity = keySet->numInputs(); +EncryptedArguments::checkAllArgs(KeySet &keySet) { + size_t arity = keySet.numInputs(); if (currentPos == arity) { return outcome::success(); } diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index d2db19715..57dcf8ff7 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -19,10 +19,11 @@ namespace clientlib { using concretelang::error::StringError; // TODO: optimize the 
move -PublicArguments::PublicArguments( - const ClientParameters &clientParameters, RuntimeContext runtimeContext, - bool clearRuntimeContext, std::vector &&preparedArgs_, - std::vector &&ciphertextBuffers_) +PublicArguments::PublicArguments(const ClientParameters &clientParameters, + RuntimeContext runtimeContext, + bool clearRuntimeContext, + std::vector &&preparedArgs_, + std::vector &&ciphertextBuffers_) : clientParameters(clientParameters), runtimeContext(runtimeContext), clearRuntimeContext(clearRuntimeContext) { preparedArgs = std::move(preparedArgs_); @@ -63,7 +64,7 @@ PublicArguments::serialize(std::ostream &ostream) { auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++]; assert(aligned != nullptr); auto offset = (size_t)preparedArgs[iPreparedArgs++]; - std::vector sizes; // includes lweSize as last dim + std::vector sizes; // includes lweSize as last dim sizes.resize(rank + 1); for (auto dim = 0u; dim < sizes.size(); dim++) { // sizes are part of the client parameters signature @@ -91,7 +92,7 @@ PublicArguments::unserializeArgs(std::istream &istream) { if (!gate.encryption.hasValue()) { return StringError("Clear values are not handled"); } - auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize(); + auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); std::vector sizes = gate.shape.dimensions; sizes.push_back(lweSize); ciphertextBuffers.push_back(unserializeTensorData(sizes, istream)); @@ -135,14 +136,17 @@ PublicArguments::unserialize(ClientParameters &clientParameters, return sArguments; } -outcome::checked, StringError> -PublicResult::decryptVector(KeySet &keySet, size_t pos) { - auto lweSize = - clientParameters.lweSecretKeyParam(clientParameters.outputs[pos]) - .lweSize(); +outcome::checked, StringError> +PublicResult::asClearTextVector(KeySet &keySet, size_t pos) { + OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); + if (!gate.isEncrypted()) { + return buffers[pos].values; + } auto buffer = buffers[pos]; - decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize); + auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); + + std::vector decryptedValues(buffer.length() / lweSize); for (size_t i = 0; i < decryptedValues.size(); i++) { auto ciphertext = &buffer.values[i * lweSize]; OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i])); @@ -150,5 +154,63 @@ PublicResult::decryptVector(KeySet &keySet, size_t pos) { return decryptedValues; } +void next_coord_index(size_t index[], size_t sizes[], size_t rank) { + // increase multi dim index + for (int r = rank - 1; r >= 0; r--) { + if (index[r] < sizes[r] - 1) { + index[r]++; + return; + } + index[r] = 0; + } +} + +size_t global_index(size_t index[], size_t sizes[], size_t strides[], + size_t rank) { + // compute global index from multi dim index + size_t g_index = 0; + size_t default_stride = 1; + for (int r = rank - 1; r >= 0; r--) { + g_index += index[r] * ((strides[r] == 0) ? 
default_stride : strides[r]); + default_stride *= sizes[r]; + } + return g_index; +} + +TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; } + +TensorData tensorDataFromMemRef(size_t memref_rank, + encrypted_scalars_t allocated, + encrypted_scalars_t aligned, size_t offset, + size_t *sizes, size_t *strides) { + TensorData result; + assert(aligned != nullptr); + result.sizes.resize(memref_rank); + for (size_t r = 0; r < memref_rank; r++) { + result.sizes[r] = sizes[r]; + } + // ephemeral multi dim index to compute global strides + size_t *index = new size_t[memref_rank]; + for (size_t r = 0; r < memref_rank; r++) { + index[r] = 0; + } + auto len = result.length(); + result.values.resize(len); + // TODO: add a fast path for dense result (no real strides) + for (size_t i = 0; i < len; i++) { + int g_index = offset + global_index(index, sizes, strides, memref_rank); + result.values[i] = aligned[offset + g_index]; + next_coord_index(index, sizes, memref_rank); + } + delete[] index; + // TEMPORARY: That quick and dirty but as this function is used only to + // convert a result of the mlir program and as data are copied here, we + // release the alocated pointer if it set. + if (allocated != nullptr) { + free(allocated); + } + return result; +} + } // namespace clientlib } // namespace concretelang diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index 5f4b83431..c94b70bc2 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -93,7 +93,7 @@ std::ostream &serializeTensorData(uint64_t *values, size_t length, return ostream; } -std::ostream &serializeTensorData(std::vector &sizes, uint64_t *values, +std::ostream &serializeTensorData(std::vector &sizes, uint64_t *values, std::ostream &ostream) { size_t length = 1; for (auto size : sizes) { @@ -107,7 +107,7 @@ std::ostream &serializeTensorData(std::vector &sizes, uint64_t *values, std::ostream &serializeTensorData(TensorData &values_and_sizes, std::ostream &ostream) { - std::vector &sizes = values_and_sizes.sizes; + std::vector &sizes = values_and_sizes.sizes; encrypted_scalars_t values = values_and_sizes.values.data(); return serializeTensorData(sizes, values, ostream); } diff --git a/compiler/lib/ServerLib/DynamicRankCall.cpp b/compiler/lib/ServerLib/DynamicRankCall.cpp index 8233a7a39..641f7690d 100644 --- a/compiler/lib/ServerLib/DynamicRankCall.cpp +++ b/compiler/lib/ServerLib/DynamicRankCall.cpp @@ -19,7 +19,7 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...), std::vector args, size_t rank) { using concretelang::clientlib::MemRefDescriptor; - constexpr auto convert = &TensorData_from_MemRef; + constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef; switch (rank) { case 0: { auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args); diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp index e7883a804..b65c12a1d 100644 --- a/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compiler/lib/ServerLib/ServerLambda.cpp @@ -24,64 +24,6 @@ using concretelang::clientlib::CircuitGateShape; using concretelang::clientlib::PublicArguments; using concretelang::error::StringError; -void next_coord_index(size_t index[], size_t sizes[], size_t rank) { - // increase multi dim index - for (int r = rank - 1; r >= 0; r--) { - if (index[r] < sizes[r] - 1) { - index[r]++; - return; - } - index[r] = 0; - } -} - -size_t global_index(size_t index[], size_t sizes[], size_t 
strides[], - size_t rank) { - // compute global index from multi dim index - size_t g_index = 0; - size_t default_stride = 1; - for (int r = rank - 1; r >= 0; r--) { - g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]); - default_stride *= sizes[r]; - } - return g_index; -} - -/** Helper function to convert from MemRefDescriptor to - * TensorData assuming MemRefDescriptor are bufferized */ -TensorData TensorData_from_MemRef(size_t memref_rank, - encrypted_scalars_t allocated, - encrypted_scalars_t aligned, size_t offset, - size_t *sizes, size_t *strides) { - TensorData result; - assert(aligned != nullptr); - result.sizes.resize(memref_rank); - for (size_t r = 0; r < memref_rank; r++) { - result.sizes[r] = sizes[r]; - } - size_t *index = new size_t[memref_rank]; // ephemeral multi dim index to - // compute global strides - for (size_t r = 0; r < memref_rank; r++) { - index[r] = 0; - } - auto len = result.length(); - result.values.resize(len); - // TODO: add a fast path for dense result (no real strides) - for (size_t i = 0; i < len; i++) { - int g_index = offset + global_index(index, sizes, strides, memref_rank); - result.values[i] = aligned[offset + g_index]; - next_coord_index(index, sizes, memref_rank); - } - delete[] index; - // TEMPORARY: That quick and dirty but as this function is used only to - // convert a result of the mlir program and as data are copied here, we - // release the alocated pointer if it set. - if (allocated != nullptr) { - free(allocated); - } - return result; -} - outcome::checked ServerLambda::loadFromModule(std::shared_ptr module, std::string funcName) { diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 2293068a6..31454e8ff 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -69,22 +69,88 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { << pos << " is null or missing"; } -llvm::Error JITLambda::invoke(Argument &args) { - size_t expectedInputs = this->type.getNumParams(); - size_t actualInputs = args.inputs.size(); - if (expectedInputs == actualInputs) { - return invokeRaw(args.rawArg); - } - return StreamStringError("invokeRaw: received ") - << actualInputs << "arguments instead of " << expectedInputs; -} - // memref is a struct which is flattened aligned, allocated pointers, offset, // and two array of rank size for sizes and strides. uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) { return 3 + 2 * rank; } +llvm::Expected> +JITLambda::call(clientlib::PublicArguments &args) { + // invokeRaw needs to have pointers on arguments and a pointers on the result + // as last argument. + // Prepare the outputs vector to store the output value of the lambda. + auto numOutputs = 0; + for (auto &output : args.clientParameters.outputs) { + if (output.shape.dimensions.empty()) { + // scalar gate + if (output.encryption.hasValue()) { + // encrypted scalar : memref + numOutputs += numArgOfRankedMemrefCallingConvention(1); + } else { + // clear scalar + numOutputs += 1; + } + } else { + // memref gate : rank+1 if the output is encrypted for the lwe size + // dimension + auto rank = output.shape.dimensions.size() + + (output.encryption.hasValue() ? 1 : 0); + numOutputs += numArgOfRankedMemrefCallingConvention(rank); + } + } + std::vector outputs(numOutputs); + // Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on + // inputs and outputs. 
+ std::vector rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ + + outputs.size()); + size_t i = 0; + // Pointers on inputs + for (auto &arg : args.preparedArgs) { + rawArgs[i++] = &arg; + } + // Pointer on runtime context, the rawArgs take pointer on actual value that + // is passed to the compiled function. + auto rtCtxPtr = &args.runtimeContext; + rawArgs[i++] = &rtCtxPtr; + // Pointers on outputs + for (auto &out : outputs) { + rawArgs[i++] = &out; + } + + // Invoke + if (auto err = invokeRaw(rawArgs)) { + return std::move(err); + } + + // Store the result to the PublicResult + std::vector buffers; + { + size_t outputOffset = 0; + for (auto &output : args.clientParameters.outputs) { + if (output.shape.dimensions.empty() && !output.encryption.hasValue()) { + // clear scalar + buffers.push_back( + clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++])); + } else { + // encrypted scalar, and tensor gate are memref + auto rank = output.shape.dimensions.size() + + (output.encryption.hasValue() ? 1 : 0); + auto allocated = (uint64_t *)outputs[outputOffset++]; + auto aligned = (uint64_t *)outputs[outputOffset++]; + auto offset = (size_t)outputs[outputOffset++]; + size_t *sizes = (size_t *)&outputs[outputOffset]; + outputOffset += rank; + size_t *strides = (size_t *)&outputs[outputOffset]; + outputOffset += rank; + buffers.push_back(clientlib::tensorDataFromMemRef( + rank, allocated, aligned, offset, sizes, strides)); + } + } + } + return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers); +} + JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // Setting the inputs auto numInputs = 0; diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index a0e717a1d..d083c8d94 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -129,7 +129,8 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, auto keySet = std::move(keySetOrErr.value()); - return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)}; + return Lambda{this->compilationContext, std::move(lambda), std::move(keySet), + *compRes.clientParameters}; } } // namespace concretelang diff --git a/compiler/tests/TestLib/testlib_unit_test.cpp b/compiler/tests/TestLib/testlib_unit_test.cpp index 270786bbd..aed2b2cd0 100644 --- a/compiler/tests/TestLib/testlib_unit_test.cpp +++ b/compiler/tests/TestLib/testlib_unit_test.cpp @@ -38,6 +38,10 @@ compile(std::string outputLib, std::string source, mlir::concretelang::JitCompilerEngine ce{ccx}; ce.setClientParametersFuncName(funcname); auto result = ce.compile(sources, outputLib); + if (!result) { + llvm::errs() << result.takeError(); + assert(false); + } assert(result); return result.get(); } @@ -72,7 +76,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { auto maybeKeySet = lambda.keySet(getTestKeySetCachePtr(), 0, 0); ASSERT_TRUE(maybeKeySet.has_value()); std::shared_ptr keySet = std::move(maybeKeySet.value()); - auto maybePublicArguments = lambda.publicArguments(1, keySet); + auto maybePublicArguments = lambda.publicArguments(1, *keySet); ASSERT_TRUE(maybePublicArguments.has_value()); auto publicArguments = std::move(maybePublicArguments.value()); @@ -80,7 +84,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { ASSERT_TRUE(publicArguments->serialize(osstream).has_value()); EXPECT_TRUE(osstream.good()); // Direct call without intermediate - EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream)); + 
EXPECT_TRUE(lambda.serializeCall(1, *keySet, osstream));
  EXPECT_TRUE(osstream.good());
}
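// --- Reviewer note (not part of the patch) ---------------------------------
// The new JITLambda::call packs each ranked-memref output into 3 + 2*rank
// invocation slots (allocated, aligned, offset, sizes[rank], strides[rank])
// and then copies it out with tensorDataFromMemRef.  Below is a standalone,
// compilable illustration of that stride-aware copy (mirroring
// next_coord_index/global_index from PublicArguments.cpp above); all names are
// local to this sketch and nothing here ships with the patch.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static void nextCoord(std::vector<size_t> &index, const size_t *sizes) {
  // Row-major increment of a multi-dimensional index.
  for (int r = (int)index.size() - 1; r >= 0; r--) {
    if (index[r] + 1 < sizes[r]) {
      index[r]++;
      return;
    }
    index[r] = 0;
  }
}

static size_t globalIndex(const std::vector<size_t> &index,
                          const size_t *sizes, const size_t *strides) {
  // Linear offset of `index` in the underlying buffer; a stride of 0 falls
  // back to the dense (row-major) stride, as in the patch.
  size_t g = 0, defaultStride = 1;
  for (int r = (int)index.size() - 1; r >= 0; r--) {
    g += index[r] * (strides[r] == 0 ? defaultStride : strides[r]);
    defaultStride *= sizes[r];
  }
  return g;
}

int main() {
  // A 2x3 view over a 2x4 row-major buffer: sizes {2,3}, strides {4,1}.
  uint64_t aligned[8] = {0, 1, 2, 3, 10, 11, 12, 13};
  size_t sizes[2] = {2, 3}, strides[2] = {4, 1}, offset = 0;
  std::vector<size_t> index(2, 0);
  std::vector<uint64_t> values;
  for (size_t i = 0; i < 2 * 3; i++) {
    values.push_back(aligned[offset + globalIndex(index, sizes, strides)]);
    nextCoord(index, sizes);
  }
  for (uint64_t v : values)
    std::printf("%llu ", (unsigned long long)v); // prints: 0 1 2 10 11 12
  std::printf("\n");
  return 0;
}
// ---------------------------------------------------------------------------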