// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_TESTLIB_ARGUMENTS_H #define CONCRETELANG_TESTLIB_ARGUMENTS_H #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySet.h" namespace mlir { namespace concretelang { class DynamicLambda; class Arguments { public: Arguments(KeySet &keySet) : currentPos(0), keySet(keySet) { keySet.setRuntimeContext(context); } ~Arguments(); // Create EncryptedArgument that use the given KeySet to perform encryption // and decryption operations. static std::shared_ptr create(KeySet &keySet); // Add a scalar argument. llvm::Error pushArg(uint64_t arg); // Add a vector-tensor argument. llvm::Error pushArg(std::vector arg); template llvm::Error pushArg(std::array arg) { return pushArg(8, (void *)arg.data(), {size}); } // Add a matrix-tensor argument. template llvm::Error pushArg(std::array, size0> arg) { return pushArg(8, (void *)arg.data(), {size0, size1}); } // Add a rank3 tensor. template llvm::Error pushArg( std::array, size1>, size0> arg) { return pushArg(8, (void *)arg.data(), {size0, size1, size2}); } // Generalize by computing shape by template recursion // Set a argument at the given pos as a 1D tensor of T. template llvm::Error pushArg(T *data, int64_t dim1) { return pushArg(data, llvm::ArrayRef(&dim1, 1)); } // Set a argument at the given pos as a tensor of T. template llvm::Error pushArg(T *data, llvm::ArrayRef shape) { return pushArg(8 * sizeof(T), static_cast(data), shape); } llvm::Error pushArg(size_t width, void *data, llvm::ArrayRef shape); // Push the runtime context to the argument list, this must be called // after each argument was pushed. llvm::Error pushContext(); template llvm::Error pushArgs(Arg0 arg0, OtherArgs... others) { auto err = pushArg(arg0); if (err) { return err; } return pushArgs(others...); } llvm::Error pushArgs() { return pushContext(); } private: friend DynamicLambda; template friend llvm::Expected invoke(DynamicLambda &lambda, const Arguments &args); llvm::Error checkPushTooManyArgs(); // Position of the next pushed argument size_t currentPos; std::vector preparedArgs; // Store allocated lwe ciphertexts (for free) std::vector allocatedCiphertexts; // Store buffers of ciphertexts std::vector ciphertextBuffers; KeySet &keySet; RuntimeContext context; }; } // namespace concretelang } // namespace mlir #endif