// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H #define CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/TestLib/Arguments.h" #include "concretelang/TestLib/DynamicModule.h" namespace mlir { namespace concretelang { template struct MemRefDescriptor; template llvm::Expected invoke(DynamicLambda &lambda, const Arguments &args) { // compile time error if used using COMPATIBLE_RESULT_TYPE = void; return (Result)(COMPATIBLE_RESULT_TYPE)0; // invoke does not accept this kind // of Result } template <> llvm::Expected invoke(DynamicLambda &lambda, const Arguments &args); template <> llvm::Expected> invoke>(DynamicLambda &lambda, const Arguments &args); template <> llvm::Expected>> invoke>>(DynamicLambda &lambda, const Arguments &args); template <> llvm::Expected>>> invoke>>>(DynamicLambda &lambda, const Arguments &args); class DynamicLambda { private: template llvm::Expected> createArguments(Args... args) { if (keySet == nullptr) { return StreamStringError("keySet was not initialized"); } auto arg = Arguments::create(*keySet); auto err = arg->pushArgs(args...); if (err) { return StreamStringError(llvm::toString(std::move(err))); } return arg; } public: static llvm::Expected load(std::string funcName, std::string outputLib); static llvm::Expected load(std::shared_ptr module, std::string funcName); template llvm::Expected call(Args... args) { auto argOrErr = createArguments(args...); if (!argOrErr) { return argOrErr.takeError(); } auto arg = argOrErr.get(); return invoke(*this, *arg); } llvm::Error generateKeySet(llvm::Optional cache = llvm::None, uint64_t seed_msb = 0, uint64_t seed_lsb = 0); protected: template friend llvm::Expected invoke(DynamicLambda &lambda, const Arguments &args); template llvm::Expected> invokeMemRefDecriptor(const Arguments &args); ClientParameters clientParameters; std::shared_ptr keySet; void *(*func)(void *...); // Retain module and open shared lib alive std::shared_ptr module; }; template class TypedDynamicLambda : public DynamicLambda { public: static llvm::Expected> load(std::string funcName, std::string outputLib) { auto lambda = DynamicLambda::load(funcName, outputLib); if (!lambda) { return lambda.takeError(); } return TypedDynamicLambda(*lambda); } llvm::Expected call(Args... args) { return DynamicLambda::call(args...); } // TODO: check parameter types TypedDynamicLambda(DynamicLambda &lambda) : DynamicLambda(lambda) { // TODO: add static check on types vs lambda inputs/outpus } }; } // namespace concretelang } // namespace mlir #endif