feat(rust): support execution with tensor args

This commit is contained in:
youben11
2022-11-29 08:58:03 +01:00
committed by Ayoub Benaissa
parent 16f3b0bbf6
commit f05b1bd1ea
3 changed files with 265 additions and 9 deletions

View File

@@ -12,6 +12,10 @@
extern "C" {
#endif
/// The CAPI should be really careful about memory allocation. Every pointer
/// returned should points to a new buffer allocated for the purpose of the
/// CAPI, and should have a respective destructor function.
/// Opaque type declarations. Inspired from
/// llvm-project/mlir/include/mlir-c/IR.h
///
@@ -204,6 +208,15 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
int64_t *dims,
size_t rank);
@@ -212,11 +225,13 @@ MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t *
lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
uint64_t *buffer);
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t *
lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
int64_t *buffer);
MLIR_CAPI_EXPORTED PublicArguments
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,

View File

@@ -509,4 +509,88 @@ mod test {
assert_eq!(result, 6);
}
}
#[test]
fn test_tensor_lambda_argument() {
unsafe {
let mut tensor_data = [1, 2, 3, 73u64];
let mut tensor_dims = [2, 2i64];
let tensor_arg =
lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2);
assert!(!lambdaArgumentIsNull(tensor_arg));
assert!(!lambdaArgumentIsScalar(tensor_arg));
assert!(lambdaArgumentIsTensor(tensor_arg));
assert_eq!(lambdaArgumentGetTensorRank(tensor_arg), 2);
assert_eq!(lambdaArgumentGetTensorDataSize(tensor_arg), 4);
let mut dims: [i64; 2] = [0, 0];
assert_eq!(
lambdaArgumentGetTensorDims(tensor_arg, dims.as_mut_ptr()),
true
);
assert_eq!(dims, tensor_dims);
let mut data: [u64; 4] = [0; 4];
assert_eq!(
lambdaArgumentGetTensorData(tensor_arg, data.as_mut_ptr()),
true
);
assert_eq!(data, tensor_data);
lambdaArgumentDestroy(tensor_arg);
}
}
#[test]
fn test_compiler_compile_and_exec_tensor_args() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> {
%0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>>
return %0 : tensor<2x3x!FHE.eint<5>>
}";
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => val + "/lib/libConcretelangRuntime.so",
Err(_e) => "".to_string(),
};
let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_tensor_args").unwrap();
let lib_support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
)
.unwrap();
// compile
let result = lib_support.compile(module_to_compile, None).unwrap();
// loading materials from compilation
// - server_lambda: used for execution
// - client_parameters: used for keygen, encryption, and evaluation keys
let server_lambda = lib_support.load_server_lambda(result).unwrap();
let client_params = lib_support.load_client_parameters(result).unwrap();
let client_support = ClientSupport::new(client_params, None).unwrap();
let key_set = client_support.keyset(None, None).unwrap();
let eval_keys = keySetGetEvaluationKeys(key_set);
// build lambda arguments from scalar and encrypt them
let args = [
lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
];
let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
// execute the compiled function on the encrypted arguments
let encrypted_result = lib_support
.server_lambda_call(server_lambda, encrypted_args, eval_keys)
.unwrap();
// decrypt the result of execution
let result_arg = client_support
.decrypt_result(encrypted_result, key_set)
.unwrap();
// check the tensor dims value from the result lambda argument
assert_eq!(lambdaArgumentGetTensorRank(result_arg), 2);
assert_eq!(lambdaArgumentGetTensorDataSize(result_arg), 6);
let mut dims = [0, 0];
assert!(lambdaArgumentGetTensorDims(result_arg, dims.as_mut_ptr()));
assert_eq!(dims, [2, 3]);
// check the tensor data from the result lambda argument
let mut data = [0; 6];
assert!(lambdaArgumentGetTensorData(result_arg, data.as_mut_ptr()));
assert_eq!(data, [2, 6, 10, 8, 7, 15]);
}
}
}

View File

@@ -7,9 +7,11 @@
#include "concretelang/CAPI/Wrappers.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/LambdaArgument.h"
#include "concretelang/Support/LambdaSupport.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
#include <numeric>
#define C_STRUCT_CLEANER(c_struct) \
auto *cpp = unwrap(c_struct); \
@@ -299,8 +301,58 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
}
// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
// size_t rank);
int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
if (rank == 0) // not a tensor
return 1;
auto size = dims[0];
for (size_t i = 1; i < rank; i++)
size *= dims[i];
return size;
}
LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
size_t rank) {
std::vector<uint8_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
size_t rank) {
std::vector<uint16_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
size_t rank) {
std::vector<uint32_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
size_t rank) {
std::vector<uint64_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>(data_vector,
dims_vector));
}
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
return unwrap(lambdaArg)
@@ -330,9 +382,114 @@ bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
}
// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
template <typename T>
bool copyTensorDataToBuffer(
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> *tensor,
uint64_t *buffer) {
auto *data = tensor->getValue();
auto sizeOrError = tensor->getNumElements();
if (!sizeOrError) {
llvm::errs() << llvm::toString(sizeOrError.takeError());
return false;
}
auto size = sizeOrError.get();
for (size_t i = 0; i < size; i++)
buffer[i] = data[i];
return true;
}
bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) {
auto arg = unwrap(lambdaArg);
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
return false;
}
size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) {
auto arg = unwrap(lambdaArg);
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return tensor->getDimensions().size();
}
return 0;
}
int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) {
auto arg = unwrap(lambdaArg);
std::vector<int64_t> dims;
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
dims = tensor->getDimensions();
} else {
return 0;
}
return std::accumulate(std::begin(dims), std::end(dims), 1,
std::multiplies<int64_t>());
}
bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) {
auto arg = unwrap(lambdaArg);
std::vector<int64_t> dims;
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
dims = tensor->getDimensions();
} else {
return false;
}
memcpy(buffer, dims.data(), sizeof(int64_t) * dims.size());
return true;
}
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
size_t argNumber, ClientParameters params,