mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(rust): support execution with tensor args
This commit is contained in:
@@ -12,6 +12,10 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// The CAPI should be really careful about memory allocation. Every pointer
|
||||
/// returned should points to a new buffer allocated for the purpose of the
|
||||
/// CAPI, and should have a respective destructor function.
|
||||
|
||||
/// Opaque type declarations. Inspired from
|
||||
/// llvm-project/mlir/include/mlir-c/IR.h
|
||||
///
|
||||
@@ -204,6 +208,15 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
|
||||
|
||||
/// Creates a lambda argument holding a single 64-bit scalar value.
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);

/// Create a lambda argument from a dense unsigned-integer tensor.
/// `dims` holds `rank` dimension sizes and `data` points to the row-major
/// elements (product-of-dims many). The implementation copies both buffers,
/// so the caller retains ownership of `data` and `dims`.
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data,
                                                             int64_t *dims,
                                                             size_t rank);
/// Same as lambdaArgumentFromTensorU8, for 16-bit elements.
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data,
                                                              int64_t *dims,
                                                              size_t rank);
/// Same as lambdaArgumentFromTensorU8, for 32-bit elements.
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data,
                                                              int64_t *dims,
                                                              size_t rank);
/// Same as lambdaArgumentFromTensorU8, for 64-bit elements.
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
                                                              int64_t *dims,
                                                              size_t rank);
|
||||
@@ -212,11 +225,13 @@ MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
|
||||
/// Returns the scalar value wrapped by `lambdaArg` (meaningful only when
/// lambdaArgumentIsScalar() holds).
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);

/// Returns true iff `lambdaArg` wraps a tensor of u8/u16/u32/u64 elements.
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
/// Copies the tensor elements (widened to uint64_t) into `buffer`, which
/// must have room for lambdaArgumentGetTensorDataSize() entries. Returns
/// false when `lambdaArg` is not a tensor.
/// NOTE(review): the former `uint64_t *lambdaArgumentGetTensorData(LambdaArgument)`
/// declaration was removed — functions with C-language linkage cannot be
/// overloaded, and only this buffer-based form is implemented.
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
                                                    uint64_t *buffer);
/// Returns the number of dimensions of the wrapped tensor (0 when
/// `lambdaArg` is not a tensor).
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
/// Returns the total element count of the wrapped tensor (0 when
/// `lambdaArg` is not a tensor).
MLIR_CAPI_EXPORTED int64_t
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
/// Copies the rank-many dimension sizes into `buffer`. Returns false when
/// `lambdaArg` is not a tensor. (The former `int64_t *` returning overload
/// was removed for the same C-linkage reason as above.)
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
                                                    int64_t *buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
|
||||
|
||||
@@ -509,4 +509,88 @@ mod test {
|
||||
assert_eq!(result, 6);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn test_tensor_lambda_argument() {
    unsafe {
        // A 2x2 u64 tensor; the CAPI copies both buffers, so they only need
        // to outlive the lambdaArgumentFromTensorU64 call.
        let mut tensor_data = [1, 2, 3, 73u64];
        let mut tensor_dims = [2, 2i64];
        let tensor_arg =
            lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2);
        assert!(!lambdaArgumentIsNull(tensor_arg));
        assert!(!lambdaArgumentIsScalar(tensor_arg));
        assert!(lambdaArgumentIsTensor(tensor_arg));
        assert_eq!(lambdaArgumentGetTensorRank(tensor_arg), 2);
        assert_eq!(lambdaArgumentGetTensorDataSize(tensor_arg), 4);
        // Round-trip the dimensions through the CAPI accessor.
        // (Was `assert_eq!(..., true)` — plain assert! is the idiomatic form.)
        let mut dims: [i64; 2] = [0, 0];
        assert!(lambdaArgumentGetTensorDims(tensor_arg, dims.as_mut_ptr()));
        assert_eq!(dims, tensor_dims);

        // Round-trip the element data.
        let mut data: [u64; 4] = [0; 4];
        assert!(lambdaArgumentGetTensorData(tensor_arg, data.as_mut_ptr()));
        assert_eq!(data, tensor_data);
        lambdaArgumentDestroy(tensor_arg);
    }
}
|
||||
|
||||
#[test]
fn test_compiler_compile_and_exec_tensor_args() {
    unsafe {
        let module_to_compile = "
func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> {
%0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>>
return %0 : tensor<2x3x!FHE.eint<5>>
}";
        let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
            Ok(val) => val + "/lib/libConcretelangRuntime.so",
            Err(_e) => "".to_string(),
        };
        let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_tensor_args").unwrap();
        let lib_support = LibrarySupport::new(
            temp_dir.path().to_str().unwrap(),
            runtime_library_path.as_str(),
        )
        .unwrap();
        // compile
        let result = lib_support.compile(module_to_compile, None).unwrap();
        // loading materials from compilation
        // - server_lambda: used for execution
        // - client_parameters: used for keygen, encryption, and evaluation keys
        let server_lambda = lib_support.load_server_lambda(result).unwrap();
        let client_params = lib_support.load_client_parameters(result).unwrap();
        let client_support = ClientSupport::new(client_params, None).unwrap();
        let key_set = client_support.keyset(None, None).unwrap();
        let eval_keys = keySetGetEvaluationKeys(key_set);
        // Build lambda arguments from tensors and encrypt them. The buffers
        // are bound to named locals: calling as_mut_ptr() on a temporary
        // array literal would tie the pointer's validity to the end of the
        // enclosing statement, which is correct today (the CAPI copies the
        // data during the call) but fragile under refactoring.
        let mut arg0_data = [1u8, 2, 3, 4, 5, 6];
        let mut arg0_dims = [2i64, 3];
        let mut arg1_data = [1u8, 4, 7, 4, 2, 9];
        let mut arg1_dims = [2i64, 3];
        let args = [
            lambdaArgumentFromTensorU8(arg0_data.as_mut_ptr(), arg0_dims.as_mut_ptr(), 2),
            lambdaArgumentFromTensorU8(arg1_data.as_mut_ptr(), arg1_dims.as_mut_ptr(), 2),
        ];
        let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
        // execute the compiled function on the encrypted arguments
        let encrypted_result = lib_support
            .server_lambda_call(server_lambda, encrypted_args, eval_keys)
            .unwrap();
        // decrypt the result of execution
        let result_arg = client_support
            .decrypt_result(encrypted_result, key_set)
            .unwrap();
        // check the tensor dims value from the result lambda argument
        assert_eq!(lambdaArgumentGetTensorRank(result_arg), 2);
        assert_eq!(lambdaArgumentGetTensorDataSize(result_arg), 6);
        let mut dims = [0, 0];
        assert!(lambdaArgumentGetTensorDims(result_arg, dims.as_mut_ptr()));
        assert_eq!(dims, [2, 3]);
        // check the tensor data from the result lambda argument:
        // element-wise sum of the two cleartext inputs
        let mut data = [0; 6];
        assert!(lambdaArgumentGetTensorData(result_arg, data.as_mut_ptr()));
        assert_eq!(data, [2, 6, 10, 8, 7, 15]);
    }
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,11 @@
|
||||
#include "concretelang/CAPI/Wrappers.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/LambdaArgument.h"
|
||||
#include "concretelang/Support/LambdaSupport.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <numeric>
|
||||
|
||||
#define C_STRUCT_CLEANER(c_struct) \
|
||||
auto *cpp = unwrap(c_struct); \
|
||||
@@ -299,8 +301,58 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
|
||||
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
|
||||
}
|
||||
|
||||
/// Returns the number of elements of a tensor whose `rank` dimension sizes
/// are given in `dims` (row-major product). A rank of 0 denotes a scalar
/// and yields 1; `dims` is not dereferenced in that case.
/// (The commented-out declaration that preceded this helper was removed.)
int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
  if (rank == 0) // not a tensor
    return 1;
  int64_t size = dims[0];
  for (size_t i = 1; i < rank; i++)
    size *= dims[i];
  return size;
}
|
||||
|
||||
/// Wraps a rank-`rank` tensor of u8 elements; both buffers are copied.
LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
                                          size_t rank) {
  const int64_t count = getSizeFromRankAndDims(rank, dims);
  std::vector<uint8_t> values(data, data + count);
  std::vector<int64_t> shape(dims, dims + rank);
  return wrap(new mlir::concretelang::TensorLambdaArgument<
              mlir::concretelang::IntLambdaArgument<uint8_t>>(values, shape));
}
|
||||
|
||||
/// Wraps a rank-`rank` tensor of u16 elements; both buffers are copied.
LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
                                           size_t rank) {
  const int64_t count = getSizeFromRankAndDims(rank, dims);
  std::vector<uint16_t> values(data, data + count);
  std::vector<int64_t> shape(dims, dims + rank);
  return wrap(new mlir::concretelang::TensorLambdaArgument<
              mlir::concretelang::IntLambdaArgument<uint16_t>>(values, shape));
}
|
||||
|
||||
/// Wraps a rank-`rank` tensor of u32 elements; both buffers are copied.
LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
                                           size_t rank) {
  const int64_t count = getSizeFromRankAndDims(rank, dims);
  std::vector<uint32_t> values(data, data + count);
  std::vector<int64_t> shape(dims, dims + rank);
  return wrap(new mlir::concretelang::TensorLambdaArgument<
              mlir::concretelang::IntLambdaArgument<uint32_t>>(values, shape));
}
|
||||
|
||||
/// Wraps a rank-`rank` tensor of u64 elements; both buffers are copied.
LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
                                           size_t rank) {
  const int64_t count = getSizeFromRankAndDims(rank, dims);
  std::vector<uint64_t> values(data, data + count);
  std::vector<int64_t> shape(dims, dims + rank);
  return wrap(new mlir::concretelang::TensorLambdaArgument<
              mlir::concretelang::IntLambdaArgument<uint64_t>>(values, shape));
}
|
||||
|
||||
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
@@ -330,9 +382,114 @@ bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
|
||||
}
|
||||
|
||||
// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
|
||||
// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
|
||||
template <typename T>
|
||||
bool copyTensorDataToBuffer(
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<T>> *tensor,
|
||||
uint64_t *buffer) {
|
||||
auto *data = tensor->getValue();
|
||||
auto sizeOrError = tensor->getNumElements();
|
||||
if (!sizeOrError) {
|
||||
llvm::errs() << llvm::toString(sizeOrError.takeError());
|
||||
return false;
|
||||
}
|
||||
auto size = sizeOrError.get();
|
||||
for (size_t i = 0; i < size; i++)
|
||||
buffer[i] = data[i];
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Copies the elements of the wrapped tensor (widened to u64) into `buffer`.
/// Returns false when `lambdaArg` does not wrap a supported tensor type.
bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) {
  namespace mc = mlir::concretelang;
  auto arg = unwrap(lambdaArg);
  // Dispatch on the concrete element width; at most one dyn_cast succeeds.
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint8_t>>>())
    return copyTensorDataToBuffer(tensor, buffer);
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint16_t>>>())
    return copyTensorDataToBuffer(tensor, buffer);
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint32_t>>>())
    return copyTensorDataToBuffer(tensor, buffer);
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint64_t>>>())
    return copyTensorDataToBuffer(tensor, buffer);
  return false;
}
|
||||
|
||||
/// Returns the number of dimensions of the wrapped tensor, or 0 when
/// `lambdaArg` does not wrap a supported tensor type.
size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) {
  namespace mc = mlir::concretelang;
  auto arg = unwrap(lambdaArg);
  // Dispatch on the concrete element width; at most one dyn_cast succeeds.
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint8_t>>>())
    return tensor->getDimensions().size();
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint16_t>>>())
    return tensor->getDimensions().size();
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint32_t>>>())
    return tensor->getDimensions().size();
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint64_t>>>())
    return tensor->getDimensions().size();
  return 0;
}
|
||||
|
||||
/// Returns the total element count (product of dimensions) of the wrapped
/// tensor, or 0 when `lambdaArg` does not wrap a supported tensor type.
int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) {
  namespace mc = mlir::concretelang;
  auto arg = unwrap(lambdaArg);
  std::vector<int64_t> dims;
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint8_t>>>()) {
    dims = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint16_t>>>()) {
    dims = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint32_t>>>()) {
    dims = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint64_t>>>()) {
    dims = tensor->getDimensions();
  } else {
    return 0;
  }
  // int64_t{1} keeps std::accumulate's accumulator 64-bit: accumulate
  // deduces its working type from the init argument, so the previous plain
  // `1` folded the product in `int` and truncated large element counts.
  return std::accumulate(std::begin(dims), std::end(dims), int64_t{1},
                         std::multiplies<int64_t>());
}
|
||||
|
||||
/// Copies the dimension sizes of the wrapped tensor into `buffer`, which
/// must have room for lambdaArgumentGetTensorRank() entries. Returns false
/// when `lambdaArg` does not wrap a supported tensor type.
bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) {
  namespace mc = mlir::concretelang;
  auto arg = unwrap(lambdaArg);
  std::vector<int64_t> shape;
  if (auto tensor = arg->dyn_cast<
          mc::TensorLambdaArgument<mc::IntLambdaArgument<uint8_t>>>()) {
    shape = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint16_t>>>()) {
    shape = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint32_t>>>()) {
    shape = tensor->getDimensions();
  } else if (auto tensor = arg->dyn_cast<
                 mc::TensorLambdaArgument<mc::IntLambdaArgument<uint64_t>>>()) {
    shape = tensor->getDimensions();
  } else {
    return false;
  }
  // Element-wise copy of the dimension sizes into the caller's buffer.
  for (size_t i = 0; i < shape.size(); i++)
    buffer[i] = shape[i];
  return true;
}
|
||||
|
||||
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
|
||||
size_t argNumber, ClientParameters params,
|
||||
|
||||
Reference in New Issue
Block a user