From f05b1bd1eaa96e5accd467abf20be2cf4fcbb1c9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 29 Nov 2022 08:58:03 +0100 Subject: [PATCH] feat(rust): support execution with tensor args --- .../concretelang-c/Support/CompilerEngine.h | 23 ++- compiler/lib/Bindings/Rust/src/compiler.rs | 84 +++++++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 167 +++++++++++++++++- 3 files changed, 265 insertions(+), 9 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index acab455dc..bf2b640d4 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -12,6 +12,10 @@ extern "C" { #endif +/// The CAPI should be really careful about memory allocation. Every pointer +/// returned should points to a new buffer allocated for the purpose of the +/// CAPI, and should have a respective destructor function. + /// Opaque type declarations. Inspired from /// llvm-project/mlir/include/mlir-c/IR.h /// @@ -204,6 +208,15 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys); MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, + int64_t *dims, + size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, + int64_t *dims, + size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, + int64_t *dims, + size_t rank); MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims, size_t rank); @@ -212,11 +225,13 @@ MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg); MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg); MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg); -MLIR_CAPI_EXPORTED uint64_t * -lambdaArgumentGetTensorData(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, + uint64_t *buffer); MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg); -MLIR_CAPI_EXPORTED int64_t * -lambdaArgumentGetTensorDims(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED int64_t +lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, + int64_t *buffer); MLIR_CAPI_EXPORTED PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber, diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 5ce184c34..7315d8420 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -509,4 +509,88 @@ mod test { assert_eq!(result, 6); } } + + #[test] + fn test_tensor_lambda_argument() { + unsafe { + let mut tensor_data = [1, 2, 3, 73u64]; + let mut tensor_dims = [2, 2i64]; + let tensor_arg = + lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2); + assert!(!lambdaArgumentIsNull(tensor_arg)); + assert!(!lambdaArgumentIsScalar(tensor_arg)); + assert!(lambdaArgumentIsTensor(tensor_arg)); + assert_eq!(lambdaArgumentGetTensorRank(tensor_arg), 2); + assert_eq!(lambdaArgumentGetTensorDataSize(tensor_arg), 4); + let mut dims: [i64; 2] = [0, 0]; + assert_eq!( + lambdaArgumentGetTensorDims(tensor_arg, dims.as_mut_ptr()), + true + ); + assert_eq!(dims, tensor_dims); + + let mut data: [u64; 4] = [0; 4]; + assert_eq!( + lambdaArgumentGetTensorData(tensor_arg, data.as_mut_ptr()), + true + ); + assert_eq!(data, tensor_data); + lambdaArgumentDestroy(tensor_arg); + } + } + + #[test] + fn test_compiler_compile_and_exec_tensor_args() { + unsafe { + let module_to_compile = " + func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> { + %0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> + return %0 : tensor<2x3x!FHE.eint<5>> + }"; + let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { + Ok(val) => val + "/lib/libConcretelangRuntime.so", + Err(_e) => "".to_string(), + }; + let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_tensor_args").unwrap(); + let lib_support = LibrarySupport::new( + temp_dir.path().to_str().unwrap(), + runtime_library_path.as_str(), + ) + .unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(result).unwrap(); + let client_params = lib_support.load_client_parameters(result).unwrap(); + let client_support = ClientSupport::new(client_params, None).unwrap(); + let key_set = client_support.keyset(None, None).unwrap(); + let eval_keys = keySetGetEvaluationKeys(key_set); + // build lambda arguments from scalar and encrypt them + let args = [ + lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_mut_ptr(), [2, 3].as_mut_ptr(), 2), + lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_mut_ptr(), [2, 3].as_mut_ptr(), 2), + ]; + let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(server_lambda, encrypted_args, eval_keys) + .unwrap(); + // decrypt the result of execution + let result_arg = client_support + .decrypt_result(encrypted_result, key_set) + .unwrap(); + // check the tensor dims value from the result lambda argument + assert_eq!(lambdaArgumentGetTensorRank(result_arg), 2); + assert_eq!(lambdaArgumentGetTensorDataSize(result_arg), 6); + let mut dims = [0, 0]; + assert!(lambdaArgumentGetTensorDims(result_arg, dims.as_mut_ptr())); + assert_eq!(dims, [2, 3]); + // check the tensor data from the result lambda argument + let mut data = [0; 6]; + assert!(lambdaArgumentGetTensorData(result_arg, data.as_mut_ptr())); + assert_eq!(data, [2, 6, 10, 8, 7, 15]); + } + } } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 324eddbd8..5a2ff75cf 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -7,9 +7,11 @@ #include "concretelang/CAPI/Wrappers.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" +#include "concretelang/Support/LambdaArgument.h" #include "concretelang/Support/LambdaSupport.h" #include "mlir/IR/Diagnostics.h" #include "llvm/Support/SourceMgr.h" +#include #define C_STRUCT_CLEANER(c_struct) \ auto *cpp = unwrap(c_struct); \ @@ -299,8 +301,58 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) { return wrap(new mlir::concretelang::IntLambdaArgument(value)); } -// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims, -// size_t rank); +int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) { + if (rank == 0) // not a tensor + return 1; + auto size = dims[0]; + for (size_t i = 1; i < rank; i++) + size *= dims[i]; + return size; +} + +LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims, + size_t rank) { + + std::vector data_vector(data, + data + getSizeFromRankAndDims(rank, dims)); + std::vector dims_vector(dims, dims + rank); + return wrap(new mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument>(data_vector, + dims_vector)); +} + +LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims, + size_t rank) { + + std::vector data_vector(data, + data + getSizeFromRankAndDims(rank, dims)); + std::vector dims_vector(dims, dims + rank); + return wrap(new mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument>(data_vector, + dims_vector)); +} + +LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims, + size_t rank) { + + std::vector data_vector(data, + data + getSizeFromRankAndDims(rank, dims)); + std::vector dims_vector(dims, dims + rank); + return wrap(new mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument>(data_vector, + dims_vector)); +} + +LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims, + size_t rank) { + + std::vector data_vector(data, + data + getSizeFromRankAndDims(rank, dims)); + std::vector dims_vector(dims, dims + rank); + return wrap(new mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument>(data_vector, + dims_vector)); +} bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) { return unwrap(lambdaArg) @@ -330,9 +382,114 @@ bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) { mlir::concretelang::IntLambdaArgument>>(); } -// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg); -// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg); -// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg); +template +bool copyTensorDataToBuffer( + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> *tensor, + uint64_t *buffer) { + auto *data = tensor->getValue(); + auto sizeOrError = tensor->getNumElements(); + if (!sizeOrError) { + llvm::errs() << llvm::toString(sizeOrError.takeError()); + return false; + } + auto size = sizeOrError.get(); + for (size_t i = 0; i < size; i++) + buffer[i] = data[i]; + return true; +} + +bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) { + auto arg = unwrap(lambdaArg); + if (auto tensor = arg->dyn_cast>>()) { + return copyTensorDataToBuffer(tensor, buffer); + } + if (auto tensor = arg->dyn_cast>>()) { + return copyTensorDataToBuffer(tensor, buffer); + } + if (auto tensor = arg->dyn_cast>>()) { + return copyTensorDataToBuffer(tensor, buffer); + } + if (auto tensor = arg->dyn_cast>>()) { + return copyTensorDataToBuffer(tensor, buffer); + } + return false; +} + +size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) { + auto arg = unwrap(lambdaArg); + if (auto tensor = arg->dyn_cast>>()) { + return tensor->getDimensions().size(); + } + if (auto tensor = arg->dyn_cast>>()) { + return tensor->getDimensions().size(); + } + if (auto tensor = arg->dyn_cast>>()) { + return tensor->getDimensions().size(); + } + if (auto tensor = arg->dyn_cast>>()) { + return tensor->getDimensions().size(); + } + return 0; +} + +int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) { + auto arg = unwrap(lambdaArg); + std::vector dims; + if (auto tensor = arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else { + return 0; + } + return std::accumulate(std::begin(dims), std::end(dims), 1, + std::multiplies()); +} + +bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) { + auto arg = unwrap(lambdaArg); + std::vector dims; + if (auto tensor = arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else if (auto tensor = + arg->dyn_cast>>()) { + dims = tensor->getDimensions(); + } else { + return false; + } + memcpy(buffer, dims.data(), sizeof(int64_t) * dims.size()); + return true; +} PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber, ClientParameters params,